NumPy配列ndarrayのサイズ1の次元を削除するnp.squeeze()
NumPy配列ndarray
のサイズが1
の次元をまとめて削除するにはnp.squeeze()
を使う。ndarray
のメソッドとしても提供されている。
任意の形状shape
に変換したい場合はnp.reshape()
、サイズ1
の新たな次元を追加したい場合はnp.newaxis
やnp.expand_dims()
を使う。以下の記事を参照。
本記事のサンプルコードのNumPyのバージョンは以下の通り。バージョンによって仕様が異なる可能性があるので注意。
import numpy as np
print(np.__version__)
# 1.26.1
numpy.squeeze()の基本的な使い方
np.squeeze()
の第一引数に配列ndarray
を指定すると、サイズが1
の次元がすべて削除された形状shape
のndarray
が返される。
a = np.arange(6).reshape(1, 2, 1, 3, 1)
print(a)
# [[[[[0]
# [1]
# [2]]]
#
#
# [[[3]
# [4]
# [5]]]]]
print(a.shape)
# (1, 2, 1, 3, 1)
a_s = np.squeeze(a)
print(a_s)
# [[0 1 2]
# [3 4 5]]
print(a_s.shape)
# (2, 3)
np.squeeze()
が返すのは元のndarray
のビュー。元のオブジェクトとビューオブジェクトはメモリを共有するので、一方の要素を変更すると他方の要素も変更される。
print(np.shares_memory(a, a_s))
# True
コピーを生成したい場合はcopy()
を使う。
a_s_copy = np.squeeze(a).copy()
print(np.shares_memory(a, a_s_copy))
# False
NumPyにおけるビューとコピーについては以下の記事を参照。
削除対象とする次元を指定: 引数axis
上の例のように、デフォルトではサイズが1
の次元がすべて削除される。
np.squeeze()
の第二引数axis
に削除対象となる次元のインデックスを指定できる。指定していない次元はサイズが1
であっても削除されない。
a = np.arange(6).reshape(1, 2, 1, 3, 1)
print(a)
# [[[[[0]
# [1]
# [2]]]
#
#
# [[[3]
# [4]
# [5]]]]]
print(a.shape)
# (1, 2, 1, 3, 1)
print(np.squeeze(a, 0))
# [[[[0]
# [1]
# [2]]]
#
#
# [[[3]
# [4]
# [5]]]]
print(np.squeeze(a, 0).shape)
# (2, 1, 3, 1)
サイズが1
ではない次元を指定したり、存在しない次元を指定するとエラー。
# print(np.squeeze(a, 1))
# ValueError: cannot select an axis to squeeze out which has size not equal to one
# print(np.squeeze(a, 5))
# AxisError: axis 5 is out of bounds for array of dimension 5
負の値で指定することも可能。-1
が最後の次元に対応し、後ろからの位置で指定できる。
print(np.squeeze(a, -1))
# [[[[0 1 2]]
#
# [[3 4 5]]]]
print(np.squeeze(a, -1).shape)
# (1, 2, 1, 3)
print(np.squeeze(a, -3))
# [[[[0]
# [1]
# [2]]
#
# [[3]
# [4]
# [5]]]]
print(np.squeeze(a, -3).shape)
# (1, 2, 3, 1)
タプルで複数の次元を指定できる。サイズが1
ではない次元や存在しない次元が含まれているとエラー。
print(np.squeeze(a, (0, -1)))
# [[[0 1 2]]
#
# [[3 4 5]]]
print(np.squeeze(a, (0, -1)).shape)
# (2, 1, 3)
# print(np.squeeze(a, (0, 1)))
# ValueError: cannot select an axis to squeeze out which has size not equal to one
numpy.ndarray.squeeze()の場合
配列ndarray
のメソッドとしてもsqueeze()
が提供されている。
使い方はnp.squeeze()
と同じ。第一引数がaxis
となる。
a = np.arange(6).reshape(1, 2, 1, 3, 1)
print(a)
# [[[[[0]
# [1]
# [2]]]
#
#
# [[[3]
# [4]
# [5]]]]]
print(a.shape)
# (1, 2, 1, 3, 1)
print(a.squeeze())
# [[0 1 2]
# [3 4 5]]
print(a.squeeze().shape)
# (2, 3)
print(a.squeeze((0, -1)))
# [[[0 1 2]]
#
# [[3 4 5]]]
print(a.squeeze((0, -1)).shape)
# (2, 1, 3)
np.squeeze()
と同じくビューを返す。元のオブジェクトの形状shape
はそのまま。inplaceな処理ではない。
a_s = a.squeeze()
print(a_s)
# [[0 1 2]
# [3 4 5]]
print(np.shares_memory(a, a_s))
# True
print(a)
# [[[[[0]
# [1]
# [2]]]
#
#
# [[[3]
# [4]
# [5]]]]]