NumPy: squeeze() to remove dimensions of size 1 from an array
In NumPy, to remove dimensions of size 1 from an array (ndarray), use the np.squeeze() function. The same functionality is also available as the squeeze() method of ndarray.
Use np.reshape() to convert an array to any shape, and np.newaxis or np.expand_dims() to add a new dimension of size 1. For details, see the following articles.
- NumPy: reshape() to change the shape of an array
- NumPy: Add new dimensions to an array (np.newaxis, np.expand_dims)
The NumPy version used in this article is as follows. Note that functionality may vary between versions.
import numpy as np
print(np.__version__)
# 1.26.1
Basic usage of the np.squeeze() function
Passing an ndarray as the first argument to np.squeeze() returns an ndarray with all dimensions of size 1 removed.
a = np.arange(6).reshape(1, 2, 1, 3, 1)
print(a)
# [[[[[0]
# [1]
# [2]]]
#
#
# [[[3]
# [4]
# [5]]]]]
print(a.shape)
# (1, 2, 1, 3, 1)
a_s = np.squeeze(a)
print(a_s)
# [[0 1 2]
# [3 4 5]]
print(a_s.shape)
# (2, 3)
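If every dimension has size 1, np.squeeze() removes all of them and the result is a 0-dimensional array, not a Python scalar. The snippet below checks this on a small illustrative array b (not from the original article). It uses item() to get a plain Python value.

```python
import numpy as np

# An array in which every dimension has size 1
b = np.array([[[7]]])
print(b.shape)
# (1, 1, 1)

# Squeezing removes all three dimensions, leaving a 0-d array
b_s = np.squeeze(b)
print(b_s, b_s.shape, b_s.ndim)
# 7 () 0

# item() converts a 0-d (single-element) array to a Python scalar
print(b_s.item())
# 7
```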
np.squeeze() returns a view of the original ndarray. Because the original object and the view share memory, changing a value in one is reflected in the other.
print(np.shares_memory(a, a_s))
# True
To get an independent array that does not share memory with the original, call copy() on the result.
a_s_copy = np.squeeze(a).copy()
print(np.shares_memory(a, a_s_copy))
# False
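The shared memory can be observed directly by writing through the view. This sketch reuses the same a as above. It shows that an assignment to the squeezed view appears in a, while an assignment to the copy does not. The index mapping a_s[i, j] == a[0, i, 0, j, 0] follows from the shape (1, 2, 1, 3, 1).

```python
import numpy as np

a = np.arange(6).reshape(1, 2, 1, 3, 1)
a_s = np.squeeze(a)              # view: shares memory with a
a_s_copy = np.squeeze(a).copy()  # independent copy

# Writing through the view changes the original array
a_s[0, 0] = 100
print(a[0, 0, 0, 0, 0])
# 100

# Writing to the copy leaves the original untouched
a_s_copy[0, 1] = -1
print(a[0, 0, 0, 1, 0])
# 1
```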
For more information about views and copies in NumPy, refer to the following article.
Specify dimensions to remove: axis
By default, np.squeeze() removes all dimensions of size 1.
To remove only specific dimensions, pass their indices as the second argument, axis. Dimensions not listed in axis remain unchanged.
a = np.arange(6).reshape(1, 2, 1, 3, 1)
print(a)
# [[[[[0]
# [1]
# [2]]]
#
#
# [[[3]
# [4]
# [5]]]]]
print(a.shape)
# (1, 2, 1, 3, 1)
print(np.squeeze(a, 0))
# [[[[0]
# [1]
# [2]]]
#
#
# [[[3]
# [4]
# [5]]]]
print(np.squeeze(a, 0).shape)
# (2, 1, 3, 1)
Specifying a dimension whose size is not 1, or a dimension that does not exist, raises an error.
# print(np.squeeze(a, 1))
# ValueError: cannot select an axis to squeeze out which has size not equal to one
# print(np.squeeze(a, 5))
# AxisError: axis 5 is out of bounds for array of dimension 5
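If you want to squeeze a list of candidate axes but skip any that do not have size 1, you can filter them before calling np.squeeze(). The function squeeze_size1 below is a hypothetical helper written for this example, not part of NumPy. It skips in-range axes whose size is not 1. It still passes out-of-range axes through so that np.squeeze() raises AxisError for them.

```python
import numpy as np

def squeeze_size1(a, axes):
    """Hypothetical helper: squeeze only the axes in `axes` that have size 1.

    In-range axes whose size is not 1 are skipped instead of raising ValueError.
    Out-of-range axes are passed through, so np.squeeze() still raises AxisError.
    """
    a = np.asarray(a)
    selected = tuple(
        ax for ax in axes
        if not (-a.ndim <= ax < a.ndim) or a.shape[ax] == 1
    )
    if not selected:
        return a  # nothing to remove
    return np.squeeze(a, axis=selected)

a = np.arange(6).reshape(1, 2, 1, 3, 1)
print(squeeze_size1(a, (0, 1)).shape)   # axis 1 (size 2) is skipped
# (2, 1, 3, 1)
print(squeeze_size1(a, (1, -1)).shape)  # only the last axis is removed
# (1, 2, 1, 3)
```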
Negative values for axis count from the end: -1 denotes the last dimension, -2 the second to last, and so on.
print(np.squeeze(a, -1))
# [[[[0 1 2]]
#
# [[3 4 5]]]]
print(np.squeeze(a, -1).shape)
# (1, 2, 1, 3)
print(np.squeeze(a, -3))
# [[[[0]
# [1]
# [2]]
#
# [[3]
# [4]
# [5]]]]
print(np.squeeze(a, -3).shape)
# (1, 2, 3, 1)
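For an array with ndim dimensions, a negative axis -k refers to the same dimension as the non-negative index ndim - k. The loop below checks this for the size-1 dimensions of a (ndim is 5, so -1, -3 and -5 correspond to 4, 2 and 0).

```python
import numpy as np

a = np.arange(6).reshape(1, 2, 1, 3, 1)

# A negative axis -k refers to the same dimension as a.ndim - k
for k in (1, 3, 5):
    neg = np.squeeze(a, -k)
    pos = np.squeeze(a, a.ndim - k)
    print(-k, a.ndim - k, neg.shape, np.array_equal(neg, pos))
# -1 4 (1, 2, 1, 3) True
# -3 2 (1, 2, 3, 1) True
# -5 0 (2, 1, 3, 1) True
```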
Multiple dimensions can be specified as a tuple. An error is raised if any of them does not have size 1 or does not exist.
print(np.squeeze(a, (0, -1)))
# [[[0 1 2]]
#
# [[3 4 5]]]
print(np.squeeze(a, (0, -1)).shape)
# (2, 1, 3)
# print(np.squeeze(a, (0, 1)))
# ValueError: cannot select an axis to squeeze out which has size not equal to one
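All indices in the tuple refer to dimensions of the original array. Squeezing one axis at a time is different, because each removal renumbers the dimensions that follow it. The sketch below shows this with the axes 0 and 2 of a. Removing the higher index first keeps the lower index valid.

```python
import numpy as np

a = np.arange(6).reshape(1, 2, 1, 3, 1)

# Both indices in the tuple refer to the dimensions of the original array a
print(np.squeeze(a, (0, 2)).shape)
# (2, 3, 1)

# After removing axis 0, the remaining axes are renumbered
step = np.squeeze(a, 0)
print(step.shape)
# (2, 1, 3, 1)
# Axis 2 of step is now the size-3 dimension,
# so np.squeeze(step, 2) would raise ValueError

# Removing the higher index first keeps the lower index valid
print(np.squeeze(np.squeeze(a, 2), 0).shape)
# (2, 3, 1)
```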
The squeeze() method of ndarray
ndarray also has a squeeze() method that works like np.squeeze(). Because the method is called on the array itself, axis is its first argument.
a = np.arange(6).reshape(1, 2, 1, 3, 1)
print(a)
# [[[[[0]
# [1]
# [2]]]
#
#
# [[[3]
# [4]
# [5]]]]]
print(a.shape)
# (1, 2, 1, 3, 1)
print(a.squeeze())
# [[0 1 2]
# [3 4 5]]
print(a.squeeze().shape)
# (2, 3)
print(a.squeeze((0, -1)))
# [[[0 1 2]]
#
# [[3 4 5]]]
print(a.squeeze((0, -1)).shape)
# (2, 1, 3)
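Both the method and the function also accept axis as a keyword argument, which can make the intent clearer. A quick comparison on the same a shows that they produce arrays of the same shape and contents.

```python
import numpy as np

a = np.arange(6).reshape(1, 2, 1, 3, 1)

# The method and the function give the same result for the same axis
m = a.squeeze(axis=(0, -1))
f = np.squeeze(a, axis=(0, -1))
print(m.shape, f.shape, np.array_equal(m, f))
# (2, 1, 3) (2, 1, 3) True

# With no axis, both remove every dimension of size 1
print(a.squeeze().shape == np.squeeze(a).shape)
# True
```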
Like np.squeeze(), this method returns a view; the shape of the original array is not changed.
a_s = a.squeeze()
print(a_s)
# [[0 1 2]
# [3 4 5]]
print(np.shares_memory(a, a_s))
# True
print(a)
# [[[[[0]
# [1]
# [2]]]
#
#
# [[[3]
# [4]
# [5]]]]]