note.nkmk.me

NumPy: Remove dimensions of size 1 from ndarray (np.squeeze)

Posted: 2020-09-24 / Tags: Python, NumPy

You can use numpy.squeeze() to remove all dimensions of size 1 from a NumPy array (numpy.ndarray). squeeze() is also available as a method of ndarray.

This post describes the following:

  • Basic usage of numpy.squeeze()
  • Specify the dimension to be deleted: axis
  • For numpy.ndarray.squeeze()

Use numpy.reshape() to convert an array to any shape, and numpy.newaxis or numpy.expand_dims() to add a new dimension of size 1.
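As a rough illustration of how these relate to squeeze(), the sketch below adds a leading dimension of size 1 in three ways and then removes it again. The array name b is only an example and does not appear elsewhere in this post.

```python
import numpy as np

b = np.arange(6).reshape(2, 3)

# Three ways to add a leading dimension of size 1
print(np.expand_dims(b, 0).shape)
# (1, 2, 3)

print(b[np.newaxis].shape)
# (1, 2, 3)

print(b.reshape(1, 2, 3).shape)
# (1, 2, 3)

# squeeze() removes the added dimension again
restored = np.squeeze(np.expand_dims(b, 0))
print(restored.shape)
# (2, 3)
```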

Basic usage of numpy.squeeze()

Passing a numpy.ndarray as the first argument of numpy.squeeze() returns a numpy.ndarray with all dimensions of size 1 removed.

import numpy as np

a = np.arange(6).reshape(1, 2, 1, 3, 1)
print(a)
# [[[[[0]
#     [1]
#     [2]]]
# 
# 
#   [[[3]
#     [4]
#     [5]]]]]

print(a.shape)
# (1, 2, 1, 3, 1)

a_s = np.squeeze(a)
print(a_s)
# [[0 1 2]
#  [3 4 5]]

print(a_s.shape)
# (2, 3)
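One edge case worth knowing: if every dimension has size 1, removing all of them leaves a zero-dimensional array rather than a Python scalar. A minimal sketch, using an example array named one that is not part of the original post:

```python
import numpy as np

one = np.array([[[5]]])
print(one.shape)
# (1, 1, 1)

z = np.squeeze(one)
print(z.shape)
# ()

print(z.ndim)
# 0

# item() extracts the single element as a Python scalar
print(z.item())
# 5
```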

numpy.squeeze() returns a view of the original numpy.ndarray. The original object and the view share memory, so changing an element in one also changes it in the other.

print(np.shares_memory(a, a_s))
# True
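To make the shared memory concrete, the sketch below writes through the view and through the original. It uses separate example names (base, view) so that the array a used in this post stays unchanged.

```python
import numpy as np

base = np.arange(6).reshape(1, 2, 1, 3, 1)
view = np.squeeze(base)

# Writing to the view changes the original
view[0, 0] = 100
print(base[0, 0, 0, 0, 0])
# 100

# Writing to the original changes the view
base[0, 1, 0, 2, 0] = -1
print(view[1, 2])
# -1
```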

If you want to make a copy, use copy().

a_s_copy = np.squeeze(a).copy()

print(np.shares_memory(a, a_s_copy))
# False

Specify the dimension to be deleted: axis

By default, all dimensions with size 1 are removed, as in the example above.

You can specify the index of the dimension to be removed with the second argument, axis, of numpy.squeeze(). Only the specified dimension is removed; other dimensions of size 1 are kept.

print(a.shape)
# (1, 2, 1, 3, 1)
print(np.squeeze(a, 0))
# [[[[0]
#    [1]
#    [2]]]
# 
# 
#  [[[3]
#    [4]
#    [5]]]]

print(np.squeeze(a, 0).shape)
# (2, 1, 3, 1)
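Specifying axis can matter when a dimension happens to have size 1 but should be kept. The sketch below assumes a hypothetical array x whose first dimension is a batch of one item and whose last dimension is a single channel; that interpretation is only an illustration, not something NumPy knows about.

```python
import numpy as np

# Hypothetical layout: (batch, height, width, channel)
x = np.zeros((1, 4, 4, 1))

# Without axis, the batch dimension is removed as well
print(np.squeeze(x).shape)
# (4, 4)

# With axis=-1, only the last dimension is removed
print(np.squeeze(x, -1).shape)
# (1, 4, 4)
```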

An error will occur if you specify a dimension whose size is not 1 or a dimension that does not exist.

# print(np.squeeze(a, 1))
# ValueError: cannot select an axis to squeeze out which has size not equal to one

# print(np.squeeze(a, 5))
# AxisError: axis 5 is out of bounds for array of dimension 5
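If you need to handle these errors, note that AxisError is a subclass of ValueError (and of IndexError), so a single except ValueError clause catches both cases. The helper try_squeeze below is a hypothetical name used only for this sketch.

```python
import numpy as np

a = np.arange(6).reshape(1, 2, 1, 3, 1)

def try_squeeze(arr, axis):
    # Return the squeezed shape, or the name of the exception raised
    try:
        return np.squeeze(arr, axis).shape
    except ValueError as e:
        return type(e).__name__

print(try_squeeze(a, 0))
# (2, 1, 3, 1)

print(try_squeeze(a, 1))
# ValueError

print(try_squeeze(a, 5))
# AxisError
```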

axis can also be a negative value, which counts from the end: -1 is the last dimension, -2 the second to last, and so on.

print(np.squeeze(a, -1))
# [[[[0 1 2]]
# 
#   [[3 4 5]]]]

print(np.squeeze(a, -1).shape)
# (1, 2, 1, 3)

print(np.squeeze(a, -3))
# [[[[0]
#    [1]
#    [2]]
# 
#   [[3]
#    [4]
#    [5]]]]

print(np.squeeze(a, -3).shape)
# (1, 2, 3, 1)
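A negative axis is equivalent to that value plus the number of dimensions (ndim). The following sketch checks this for the size-1 dimensions of a:

```python
import numpy as np

a = np.arange(6).reshape(1, 2, 1, 3, 1)

for neg in (-1, -3, -5):
    pos = neg + a.ndim  # equivalent non-negative index
    same = np.squeeze(a, neg).shape == np.squeeze(a, pos).shape
    print(neg, pos, same)
# -1 4 True
# -3 2 True
# -5 0 True
```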

You can specify multiple dimensions with a tuple. An error is raised if the tuple includes a dimension whose size is not 1 or a dimension that does not exist.

print(np.squeeze(a, (0, -1)))
# [[[0 1 2]]
# 
#  [[3 4 5]]]

print(np.squeeze(a, (0, -1)).shape)
# (2, 1, 3)

# print(np.squeeze(a, (0, 1)))
# ValueError: cannot select an axis to squeeze out which has size not equal to one

For numpy.ndarray.squeeze()

squeeze() is also provided as a method of numpy.ndarray.

Usage is the same as numpy.squeeze(), except that the array is the object the method is called on, so the first argument is axis.

print(a.squeeze())
# [[0 1 2]
#  [3 4 5]]

print(a.squeeze().shape)
# (2, 3)

print(a.squeeze((0, -1)))
# [[[0 1 2]]
# 
#  [[3 4 5]]]

print(a.squeeze((0, -1)).shape)
# (2, 1, 3)
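axis can also be passed as a keyword argument. A short sketch comparing the method with the function:

```python
import numpy as np

a = np.arange(6).reshape(1, 2, 1, 3, 1)

print(a.squeeze(axis=2).shape)
# (1, 2, 3, 1)

# The method and the function give the same result
print(np.array_equal(a.squeeze(axis=2), np.squeeze(a, axis=2)))
# True
```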

The squeeze() method also returns a view, like numpy.squeeze(). The original object is not modified.

a_s = a.squeeze()
print(a_s)
# [[0 1 2]
#  [3 4 5]]

print(np.shares_memory(a, a_s))
# True

print(a)
# [[[[[0]
#     [1]
#     [2]]]
# 
# 
#   [[[3]
#     [4]
#     [5]]]]]