NumPy配列ndarrayに次元を追加するnp.newaxis, np.expand_dims()

Modified: | Tags: Python, NumPy

NumPy配列ndarrayに新たな次元を追加する(次元を増やす)には、np.newaxis, np.expand_dims()およびnp.reshape()(またはndarrayreshape()メソッド)を使う方法がある。

np.reshape()あるいはndarrayreshape()メソッドは次元を追加するだけでなく任意の形状shapeへの変換が可能。本記事の最後でも触れるが、詳細は以下の記事を参照。

np.newaxisnp.expand_dims()ではサイズ1の新たな次元を追加できるが、反対にサイズ1の次元を削除するにはnp.squeeze()を使う。

本記事のサンプルコードのNumPyのバージョンは以下の通り。バージョンによって仕様が異なる可能性があるので注意。

import numpy as np

print(np.__version__)
# 1.26.1

np.newaxisの使い方

np.newaxisはNone

np.newaxisNoneのエイリアス。

print(np.newaxis is None)
# True

分かりやすくするように別名が付けられているだけなので、以下のサンプルコードのnp.newaxisNoneに置き換えても同じように動作する。

np.newaxisで新たな次元を追加

[]によるインデックスの中でnp.newaxisを使うと、その位置にサイズが1の新たな次元が追加される。

a = np.arange(6).reshape(2, 3)
print(a)
# [[0 1 2]
#  [3 4 5]]

print(a.shape)
# (2, 3)

print(a[:, :, np.newaxis])
# [[[0]
#   [1]
#   [2]]
# 
#  [[3]
#   [4]
#   [5]]]

print(a[:, :, np.newaxis].shape)
# (2, 3, 1)
print(a[:, np.newaxis, :])
# [[[0 1 2]]
# 
#  [[3 4 5]]]

print(a[:, np.newaxis, :].shape)
# (2, 1, 3)
print(a[np.newaxis, :, :])
# [[[0 1 2]
#   [3 4 5]]]

print(a[np.newaxis, :, :].shape)
# (1, 2, 3)

[]内の末尾の:は省略可能。先頭に次元を追加する場合は[np.newaxis]でよい。

print(a[:, np.newaxis])
# [[[0 1 2]]
# 
#  [[3 4 5]]]

print(a[:, np.newaxis].shape)
# (2, 1, 3)
print(a[np.newaxis])
# [[[0 1 2]
#   [3 4 5]]]

print(a[np.newaxis].shape)
# (1, 2, 3)

連続する:...で置き換えることができる。例のように三次元くらいだと:でもそこまで面倒ではないが、次元数が多いndarrayの末尾に次元を追加する場合は...を使うと楽。

print(a[..., np.newaxis])
# [[[0]
#   [1]
#   [2]]
# 
#  [[3]
#   [4]
#   [5]]]

print(a[..., np.newaxis].shape)
# (2, 3, 1)

複数のnp.newaxisを同時に使ってもよい。複数の次元が追加される。

print(a[np.newaxis, :, np.newaxis, :, np.newaxis])
# [[[[[0]
#     [1]
#     [2]]]
# 
# 
#   [[[3]
#     [4]
#     [5]]]]]

print(a[np.newaxis, :, np.newaxis, :, np.newaxis].shape)
# (1, 2, 1, 3, 1)

np.newaxisによる次元追加で返されるのは元のオブジェクトのビュー。元のオブジェクトとビューオブジェクトはメモリを共有するので、一方の要素を変更すると他方の要素も変更される。

a_newaxis = a[:, :, np.newaxis]

print(np.shares_memory(a, a_newaxis))
# True

np.newaxisでブロードキャストを制御

NumPy配列ndarray同士の二項演算(四則演算など)では、ブロードキャストという仕組みによって、それぞれの形状shapeが同じになるように自動的に変換される。

a = np.zeros(27, dtype=np.int64).reshape(3, 3, 3)
print(a)
# [[[0 0 0]
#   [0 0 0]
#   [0 0 0]]
# 
#  [[0 0 0]
#   [0 0 0]
#   [0 0 0]]
# 
#  [[0 0 0]
#   [0 0 0]
#   [0 0 0]]]

print(a.shape)
# (3, 3, 3)

b = np.arange(9).reshape(3, 3)
print(b)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

print(b.shape)
# (3, 3)

print(a + b)
# [[[0 1 2]
#   [3 4 5]
#   [6 7 8]]
# 
#  [[0 1 2]
#   [3 4 5]
#   [6 7 8]]
# 
#  [[0 1 2]
#   [3 4 5]
#   [6 7 8]]]

詳細は上の関連記事を参照されたいが、ブロードキャストでは次元数を揃えるために次元数が少ない方の配列の先頭に新たな次元を加えるというルールがある。

np.newaxisで先頭に新たな次元を追加した場合はブロードキャストにより自動的に変換された場合と同じ結果になる。

print(b[np.newaxis, :, :].shape)
# (1, 3, 3)

print(a + b[np.newaxis, :, :])
# [[[0 1 2]
#   [3 4 5]
#   [6 7 8]]
# 
#  [[0 1 2]
#   [3 4 5]
#   [6 7 8]]
# 
#  [[0 1 2]
#   [3 4 5]
#   [6 7 8]]]

次元を追加する位置を変えると異なる結果となる。

print(b[:, np.newaxis, :].shape)
# (3, 1, 3)

print(a + b[:, np.newaxis, :])
# [[[0 1 2]
#   [0 1 2]
#   [0 1 2]]
# 
#  [[3 4 5]
#   [3 4 5]
#   [3 4 5]]
# 
#  [[6 7 8]
#   [6 7 8]
#   [6 7 8]]]
print(b[:, :, np.newaxis].shape)
# (3, 3, 1)

print(a + b[:, :, np.newaxis])
# [[[0 0 0]
#   [1 1 1]
#   [2 2 2]]
# 
#  [[3 3 3]
#   [4 4 4]
#   [5 5 5]]
# 
#  [[6 6 6]
#   [7 7 7]
#   [8 8 8]]]

例えば、カラー画像の配列(形状: (高さ, 幅, 色))と単色画像の配列(形状: (高さ, 幅))を同じ位置同士で足したり引いたりしたい場合、そのままだとブロードキャストできずにエラーとなるが、単色画像の最後に新たな次元を追加するとうまくいく。このあたりの説明も以下の記事を参照。

np.expand_dims()で新たな次元を追加

ndarrayに新たな次元を追加する方法として、np.expand_dims()関数を使う方法もある。

第一引数aに元のndarray、第二引数axisに次元を追加する位置を指定する。

a = np.arange(6).reshape(2, 3)
print(a)
# [[0 1 2]
#  [3 4 5]]

print(np.expand_dims(a, 0))
# [[[0 1 2]
#   [3 4 5]]]

print(np.expand_dims(a, 0).shape)
# (1, 2, 3)

以下のように、任意の位置に新たな次元を挿入できる。形状shapeのみを示す。

print(np.expand_dims(a, 0).shape)
# (1, 2, 3)

print(np.expand_dims(a, 1).shape)
# (2, 1, 3)

print(np.expand_dims(a, 2).shape)
# (2, 3, 1)

第二引数axisには負の値も指定可能。-1が最後の次元に対応し、後ろからの位置を指定できる。

print(np.expand_dims(a, -1).shape)
# (2, 3, 1)

print(np.expand_dims(a, -2).shape)
# (2, 1, 3)

print(np.expand_dims(a, -3).shape)
# (1, 2, 3)

NumPy1.17までは、第二引数axisに範囲外の値を指定してもエラーにならず、末尾または先頭に次元が追加されていたが、NumPy1.18以降はエラーとなる。

# print(np.expand_dims(a, 3).shape)
# AxisError: axis 3 is out of bounds for array of dimension 3

# print(np.expand_dims(a, -4).shape)
# AxisError: axis -4 is out of bounds for array of dimension 3

また、NumPy1.18以降、第二引数axisにタプルで複数の位置を指定して複数の次元を一度に追加できるようになった。

print(np.expand_dims(a, (0, 1, -1)).shape)
# (1, 1, 2, 3, 1)

np.newaxisと同様に、np.expand_dims()もビューを返す。

a_expand_dims = np.expand_dims(a, 0)

print(np.shares_memory(a, a_expand_dims))
# True

例は省略するが、上述のnp.newaxisのように、np.expand_dims()で新たな次元を追加してブロードキャストを制御することももちろん可能。

np.reshape()で新たな次元を追加

配列ndarrayの形状shapeを変換する方法としてnp.reshape()関数がある。reshape()ndarrayのメソッドとしても提供されている。reshape()も(可能な限り)ビューを返す。詳細は以下の記事を参照。

reshape()に新たな次元を追加した形状を指定すれば当然ながらそのように変換される。np.newaxisnp.expand_dims()を使った場合と同じ結果となる。

a = np.arange(6).reshape(2, 3)
print(a)
# [[0 1 2]
#  [3 4 5]]

print(a.shape)
# (2, 3)

print(a[np.newaxis])
# [[[0 1 2]
#   [3 4 5]]]

print(a[np.newaxis].shape)
# (1, 2, 3)

print(np.expand_dims(a, 0))
# [[[0 1 2]
#   [3 4 5]]]

print(np.expand_dims(a, 0).shape)
# (1, 2, 3)

print(a.reshape(1, 2, 3))
# [[[0 1 2]
#   [3 4 5]]]

print(a.reshape(1, 2, 3).shape)
# (1, 2, 3)

上の例からも分かるように、np.newaxisnp.expand_dims()を使うと、追加する次元以外の次元のサイズ(元の次元のサイズ)を明示的に指定する必要がない、というメリットがある。

関連カテゴリー

関連記事