note.nkmk.me

NumPy配列ndarrayのサイズ1の次元を削除するnp.squeeze()

Date: 2019-12-12 / tags: Python, NumPy

NumPy配列ndarrayのサイズ(大きさ)が1の次元をまとめて削除するにはnumpy.squeeze()を使う。numpy.ndarrayのメソッドとしても提供されている。

ここでは、以下の内容について説明する。

  • numpy.squeeze()の基本的な使い方
  • 削除対象とする次元を指定: 引数axis
  • numpy.ndarray.squeeze()の場合

任意の形状shapeに変換したい場合はnumpy.reshape()、サイズ1の新たな次元を追加したい場合はnumpy.newaxis, numpy.expand_dims()を使う。以下の記事を参照。

スポンサーリンク

numpy.squeeze()の基本的な使い方

numpy.squeeze()の第一引数にnumpy.ndarrayを指定すると、サイズが1の次元がすべて削除された形状shapenumpy.ndarrayが返される。

import numpy as np

a = np.arange(6).reshape(1, 2, 1, 3, 1)
print(a)
# [[[[[0]
#     [1]
#     [2]]]
# 
# 
#   [[[3]
#     [4]
#     [5]]]]]

print(a.shape)
# (1, 2, 1, 3, 1)

a_s = np.squeeze(a)
print(a_s)
# [[0 1 2]
#  [3 4 5]]

print(a_s.shape)
# (2, 3)

numpy.squeeze()が返すのは元のnumpy.ndarrayのビュー。元のオブジェクトとビューオブジェクトはメモリを共有するので、いずれかの要素を変更すると他方の要素も変更される。

print(np.shares_memory(a, a_s))
# True

コピーを生成したい場合はcopy()を使う。

a_s_copy = np.squeeze(a).copy()

print(np.shares_memory(a, a_s_copy))
# False

NumPyにおけるビューとコピーについては以下の記事を参照。

削除対象とする次元を指定: 引数axis

上の例のように、デフォルトではサイズが1の次元がすべて削除される。

numpy.squeeze()の第二引数axisに削除対象となる次元のインデックスを指定できる。指定したインデックスではない次元はサイズが1であっても削除されない。

print(np.squeeze(a, 0))
# [[[[0]
#    [1]
#    [2]]]
# 
# 
#  [[[3]
#    [4]
#    [5]]]]

print(np.squeeze(a, 0).shape)
# (2, 1, 3, 1)

サイズが1ではない次元を指定したり、存在しない次元を指定するとエラー。

# print(np.squeeze(a, 1))
# ValueError: cannot select an axis to squeeze out which has size not equal to one

# print(np.squeeze(a, 5))
# AxisError: axis 5 is out of bounds for array of dimension 5

負の値で指定することも可能。-1が最後の次元に対応し、後ろからの位置で指定できる。

print(np.squeeze(a, -1))
# [[[[0 1 2]]
# 
#   [[3 4 5]]]]

print(np.squeeze(a, -1).shape)
# (1, 2, 1, 3)

print(np.squeeze(a, -3))
# [[[[0]
#    [1]
#    [2]]
# 
#   [[3]
#    [4]
#    [5]]]]

print(np.squeeze(a, -3).shape)
# (1, 2, 3, 1)

タプルで複数の次元を指定できる。サイズが1ではない次元や存在しない次元が含まれているとエラー。

print(np.squeeze(a, (0, -1)))
# [[[0 1 2]]
# 
#  [[3 4 5]]]

print(np.squeeze(a, (0, -1)).shape)
# (2, 1, 3)

# print(np.squeeze(a, (0, 1)))
# ValueError: cannot select an axis to squeeze out which has size not equal to one

numpy.ndarray.squeeze()の場合

numpy.ndarrayのメソッドとしてもsqueeze()が提供されている。

使い方はnumpy.squeeze()と同じ。第一引数がaxisとなる。

print(a.squeeze())
# [[0 1 2]
#  [3 4 5]]

print(a.squeeze().shape)
# (2, 3)

print(a.squeeze((0, -1)))
# [[[0 1 2]]
# 
#  [[3 4 5]]]

print(a.squeeze((0, -1)).shape)
# (2, 1, 3)

numpy.squeeze()と同じくビューを返す。元のオブジェクトの形状shapeはそのまま。inplaceな処理ではない。

a_s = a.squeeze()
print(a_s)
# [[0 1 2]
#  [3 4 5]]

print(np.shares_memory(a, a_s))
# True

print(a)
# [[[[[0]
#     [1]
#     [2]]]
# 
# 
#   [[[3]
#     [4]
#     [5]]]]]
スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事