NumPy配列ndarrayのサイズ1の次元を削除するnp.squeeze()

Modified: | Tags: Python, NumPy

NumPy配列ndarrayのサイズが1の次元をまとめて削除するにはnp.squeeze()を使う。ndarrayのメソッドとしても提供されている。

任意の形状shapeに変換したい場合はnp.reshape()、サイズ1の新たな次元を追加したい場合はnp.newaxisnp.expand_dims()を使う。以下の記事を参照。

本記事のサンプルコードのNumPyのバージョンは以下の通り。バージョンによって仕様が異なる可能性があるので注意。

import numpy as np

print(np.__version__)
# 1.26.1

numpy.squeeze()の基本的な使い方

np.squeeze()の第一引数に配列ndarrayを指定すると、サイズが1の次元がすべて削除された形状shapendarrayが返される。

a = np.arange(6).reshape(1, 2, 1, 3, 1)
print(a)
# [[[[[0]
#     [1]
#     [2]]]
# 
# 
#   [[[3]
#     [4]
#     [5]]]]]

print(a.shape)
# (1, 2, 1, 3, 1)

a_s = np.squeeze(a)
print(a_s)
# [[0 1 2]
#  [3 4 5]]

print(a_s.shape)
# (2, 3)

np.squeeze()が返すのは元のndarrayのビュー。元のオブジェクトとビューオブジェクトはメモリを共有するので、一方の要素を変更すると他方の要素も変更される。

print(np.shares_memory(a, a_s))
# True

コピーを生成したい場合はcopy()を使う。

a_s_copy = np.squeeze(a).copy()

print(np.shares_memory(a, a_s_copy))
# False

NumPyにおけるビューとコピーについては以下の記事を参照。

削除対象とする次元を指定: 引数axis

上の例のように、デフォルトではサイズが1の次元がすべて削除される。

np.squeeze()の第二引数axisに削除対象となる次元のインデックスを指定できる。指定していない次元はサイズが1であっても削除されない。

a = np.arange(6).reshape(1, 2, 1, 3, 1)
print(a)
# [[[[[0]
#     [1]
#     [2]]]
# 
# 
#   [[[3]
#     [4]
#     [5]]]]]

print(a.shape)
# (1, 2, 1, 3, 1)

print(np.squeeze(a, 0))
# [[[[0]
#    [1]
#    [2]]]
# 
# 
#  [[[3]
#    [4]
#    [5]]]]

print(np.squeeze(a, 0).shape)
# (2, 1, 3, 1)

サイズが1ではない次元を指定したり、存在しない次元を指定するとエラー。

# print(np.squeeze(a, 1))
# ValueError: cannot select an axis to squeeze out which has size not equal to one

# print(np.squeeze(a, 5))
# AxisError: axis 5 is out of bounds for array of dimension 5

負の値で指定することも可能。-1が最後の次元に対応し、後ろからの位置で指定できる。

print(np.squeeze(a, -1))
# [[[[0 1 2]]
# 
#   [[3 4 5]]]]

print(np.squeeze(a, -1).shape)
# (1, 2, 1, 3)

print(np.squeeze(a, -3))
# [[[[0]
#    [1]
#    [2]]
# 
#   [[3]
#    [4]
#    [5]]]]

print(np.squeeze(a, -3).shape)
# (1, 2, 3, 1)

タプルで複数の次元を指定できる。サイズが1ではない次元や存在しない次元が含まれているとエラー。

print(np.squeeze(a, (0, -1)))
# [[[0 1 2]]
# 
#  [[3 4 5]]]

print(np.squeeze(a, (0, -1)).shape)
# (2, 1, 3)

# print(np.squeeze(a, (0, 1)))
# ValueError: cannot select an axis to squeeze out which has size not equal to one

numpy.ndarray.squeeze()の場合

配列ndarrayのメソッドとしてもsqueeze()が提供されている。

使い方はnp.squeeze()と同じ。第一引数がaxisとなる。

a = np.arange(6).reshape(1, 2, 1, 3, 1)
print(a)
# [[[[[0]
#     [1]
#     [2]]]
# 
# 
#   [[[3]
#     [4]
#     [5]]]]]

print(a.shape)
# (1, 2, 1, 3, 1)

print(a.squeeze())
# [[0 1 2]
#  [3 4 5]]

print(a.squeeze().shape)
# (2, 3)

print(a.squeeze((0, -1)))
# [[[0 1 2]]
# 
#  [[3 4 5]]]

print(a.squeeze((0, -1)).shape)
# (2, 1, 3)

np.squeeze()と同じくビューを返す。元のオブジェクトの形状shapeはそのまま。inplaceな処理ではない。

a_s = a.squeeze()
print(a_s)
# [[0 1 2]
#  [3 4 5]]

print(np.shares_memory(a, a_s))
# True

print(a)
# [[[[[0]
#     [1]
#     [2]]]
# 
# 
#   [[[3]
#     [4]
#     [5]]]]]

関連カテゴリー

関連記事