note.nkmk.me

NumPy配列ndarrayに次元を追加するnp.newaxis, np.expand_dims()

Date: 2019-02-14 / Modified: 2019-12-12 / tags: Python, NumPy

NumPy配列ndarrayに新たな次元を追加する(次元を増やす)には、np.newaxis, np.expand_dims()およびnp.reshape()(またはndarrayreshape()メソッド)を使う方法がある。

ここでは以下の内容について説明する。

  • np.newaxisの使い方
    • np.newaxisNone
    • np.newaxisで新たな次元を追加
    • np.newaxisでブロードキャストを制御
  • np.expand_dims()で新たな次元を追加
  • np.reshape()で新たな次元を追加

np.reshape()あるいはndarrayreshape()メソッドは次元を追加するだけでなく任意の形状shapeへの変更が可能。本記事の最後でも触れるが、詳細は以下の記事を参照。

np.newaxisnp.expand_dims()ではサイズ1の新たな次元を追加できるが、反対にサイズ1の次元を削除するにはnp.squeeze()を使う。

スポンサーリンク

np.newaxisの使い方

np.newaxisはNone

np.newaxisNoneのエイリアス。

import numpy as np

print(np.newaxis is None)
# True

分かりやすくするように別名が付けられているだけなので、以下のサンプルコードのnp.newaxisNoneに置き換えても同じように動作する。

np.newaxisで新たな次元を追加

[]によるインデックスの中でnp.newaxisを使うとその位置にサイズが1の新たな次元が追加される。

a = np.arange(6).reshape(2, 3)
print(a)
# [[0 1 2]
#  [3 4 5]]

print(a.shape)
# (2, 3)

print(a[:, :, np.newaxis])
# [[[0]
#   [1]
#   [2]]
# 
#  [[3]
#   [4]
#   [5]]]

print(a[:, :, np.newaxis].shape)
# (2, 3, 1)
print(a[:, np.newaxis, :])
# [[[0 1 2]]
# 
#  [[3 4 5]]]

print(a[:, np.newaxis, :].shape)
# (2, 1, 3)
print(a[np.newaxis, :, :])
# [[[0 1 2]
#   [3 4 5]]]

print(a[np.newaxis, :, :].shape)
# (1, 2, 3)

[]内の末尾の:は省略可能。

print(a[:, np.newaxis])
# [[[0 1 2]]
# 
#  [[3 4 5]]]

print(a[:, np.newaxis].shape)
# (2, 1, 3)
print(a[np.newaxis])
# [[[0 1 2]
#   [3 4 5]]]

print(a[np.newaxis].shape)
# (1, 2, 3)

また、連続する:...で置き換えることができる。例のように三次元くらいだと:でもそこまで面倒ではないが、次元数が多いndarrayの末尾に次元を追加するような場合は...を使うと楽。

print(a[..., np.newaxis])
# [[[0]
#   [1]
#   [2]]
# 
#  [[3]
#   [4]
#   [5]]]

print(a[..., np.newaxis].shape)
# (2, 3, 1)

複数のnp.newaxisを一度に使ってもOK。複数の次元が追加される。

print(a[np.newaxis, :, np.newaxis, :, np.newaxis])
# [[[[[0]
#     [1]
#     [2]]]
# 
# 
#   [[[3]
#     [4]
#     [5]]]]]

print(a[np.newaxis, :, np.newaxis, :, np.newaxis].shape)
# (1, 2, 1, 3, 1)

np.newaxisによる次元追加で返されるのは元のオブジェクトのビュー。元のオブジェクトとビューオブジェクトはメモリを共有するので、いずれかの要素を変更すると他方の要素も変更される。

a_newaxis = a[:, :, np.newaxis]

print(np.shares_memory(a, a_newaxis))
# True

np.newaxisでブロードキャストを制御

NumPy配列ndarray同士の二項演算(四則演算など)ではブロードキャスト(broadcasting)という仕組みによりそれぞれの形状shapeが同じになるように自動的に変換される。

a = np.zeros(27, dtype=np.int).reshape(3, 3, 3)
print(a)
# [[[0 0 0]
#   [0 0 0]
#   [0 0 0]]
# 
#  [[0 0 0]
#   [0 0 0]
#   [0 0 0]]
# 
#  [[0 0 0]
#   [0 0 0]
#   [0 0 0]]]

print(a.shape)
# (3, 3, 3)

b = np.arange(9).reshape(3, 3)
print(b)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

print(b.shape)
# (3, 3)

print(a + b)
# [[[0 1 2]
#   [3 4 5]
#   [6 7 8]]
# 
#  [[0 1 2]
#   [3 4 5]
#   [6 7 8]]
# 
#  [[0 1 2]
#   [3 4 5]
#   [6 7 8]]]

詳細は上の関連記事を参照されたいが、ブロードキャストでは次元数を揃えるために次元数が少ない方の配列の先頭に新たな次元を加えるというルールがある。

np.newaxisで先頭に新たな次元を追加した場合はブロードキャストにより自動的に変換された場合と同じ結果になる。

print(b[np.newaxis, :, :].shape)
# (1, 3, 3)

print(a + b[np.newaxis, :, :])
# [[[0 1 2]
#   [3 4 5]
#   [6 7 8]]
# 
#  [[0 1 2]
#   [3 4 5]
#   [6 7 8]]
# 
#  [[0 1 2]
#   [3 4 5]
#   [6 7 8]]]

追加する位置を変えると異なる結果となる。

print(b[:, np.newaxis, :].shape)
# (3, 1, 3)

print(a + b[:, np.newaxis, :])
# [[[0 1 2]
#   [0 1 2]
#   [0 1 2]]
# 
#  [[3 4 5]
#   [3 4 5]
#   [3 4 5]]
# 
#  [[6 7 8]
#   [6 7 8]
#   [6 7 8]]]
print(b[:, :, np.newaxis].shape)
# (3, 3, 1)

print(a + b[:, :, np.newaxis])
# [[[0 0 0]
#   [1 1 1]
#   [2 2 2]]
# 
#  [[3 3 3]
#   [4 4 4]
#   [5 5 5]]
# 
#  [[6 6 6]
#   [7 7 7]
#   [8 8 8]]]

例えば、カラー画像の配列(形状: (高さ, 幅, 色))と単色画像(形状: (高さ, 幅))の配列を同じ位置同士で足したり引いたりしたい場合、そのままだとブロードキャストできずにエラーとなるが、単色画像の最後に新たな次元を追加するとうまくいく。このあたりの説明も以下の記事を参照。

np.expand_dims()で新たな次元を追加

ndarrayに新たな次元を追加する方法として、np.expand_dims()関数を使う方法もある。

第一引数aに元のndarray、第二引数axisに次元を追加する位置を指定する。

a = np.arange(6).reshape(2, 3)
print(a)
# [[0 1 2]
#  [3 4 5]]

print(np.expand_dims(a, 0))
# [[[0 1 2]
#   [3 4 5]]]

print(np.expand_dims(a, 0).shape)
# (1, 2, 3)

以下のように、任意の位置に新たな次元を挿入できる。形状shapeのみを示す。

print(np.expand_dims(a, 0).shape)
# (1, 2, 3)

print(np.expand_dims(a, 1).shape)
# (2, 1, 3)

print(np.expand_dims(a, 2).shape)
# (2, 3, 1)

第二引数axisには負の値も指定可能。-1が最後の次元に対応し、後ろからの位置を指定できる。

print(np.expand_dims(a, -1).shape)
# (2, 3, 1)

print(np.expand_dims(a, -2).shape)
# (2, 1, 3)

print(np.expand_dims(a, -3).shape)
# (1, 2, 3)

NumPy1.17時点では、第二引数axisaxis > a.ndimまたはaxis < -a.ndim - 1ndimは次元数)となる値を指定してもエラーにはならず、末尾または先頭に次元が追加される。ただし、警告メッセージにあるように将来的にはエラーとなるとのことなので避けたほうがいいだろう。

print(np.expand_dims(a, 3).shape)
# (2, 3, 1)
# 
# /usr/local/lib/python3.7/site-packages/ipykernel_launcher.py:1: DeprecationWarning: Both axis > a.ndim and axis < -a.ndim - 1 are deprecated and will raise an AxisError in the future.
#   """Entry point for launching an IPython kernel.

print(np.expand_dims(a, -4).shape)
# (2, 1, 3)
# 
# /usr/local/lib/python3.7/site-packages/ipykernel_launcher.py:1: DeprecationWarning: Both axis > a.ndim and axis < -a.ndim - 1 are deprecated and will raise an AxisError in the future.
#   """Entry point for launching an IPython kernel.

また、第二引数axisに指定できるのは整数値のみ。タプルなどで複数の位置を指定して複数の次元を一度に追加することはできない。

# print(np.expand_dims(a, (0, 1)).shape)
# TypeError: '>' not supported between instances of 'tuple' and 'int'

np.newaxisと同様に、np.expand_dims()もビューを返す。

a_expand_dims = np.expand_dims(a, 0)

print(np.shares_memory(a, a_expand_dims))
# True

例は省略するが、上述のnp.newaxisのように、np.expand_dims()で新たな次元を追加してブロードキャストを制御することももちろん可能。

np.reshape()で新たな次元を追加

配列ndarrayの形状shapeを変換する方法としてnp.reshape()関数がある。reshape()ndarrayのメソッドとしても提供されている。reshape()も(可能な限り)ビューを返す。詳細は以下の記事を参照。

reshape()に新たな次元を追加した形状を指定すれば当然ながらそのように変換される。np.newaxisnp.expand_dims()を使った場合と同じ結果となる。

a = np.arange(6).reshape(2, 3)
print(a)
# [[0 1 2]
#  [3 4 5]]

print(a.shape)
# (2, 3)

print(a[np.newaxis])
# [[[0 1 2]
#   [3 4 5]]]

print(a[np.newaxis].shape)
# (1, 2, 3)

print(np.expand_dims(a, 0))
# [[[0 1 2]
#   [3 4 5]]]

print(np.expand_dims(a, 0).shape)
# (1, 2, 3)

print(a.reshape(1, 2, 3))
# [[[0 1 2]
#   [3 4 5]]]

print(a.reshape(1, 2, 3).shape)
# (1, 2, 3)

上の例からも分かるように、np.newaxisnp.expand_dims()を使うと追加する次元以外の次元のサイズ(元の次元のサイズ)を明示的に指定する必要がない、というメリットがある。

reshape()でも、先頭か末尾に次元を追加する場合は、*を使ってタプルを展開して引数に指定することで、サイズを明示的に指定しなくてもよくなる。お好みで。

print(a.reshape(1, *a.shape))
# [[[0 1 2]
#   [3 4 5]]]

print(a.reshape(1, *a.shape).shape)
# (1, 2, 3)
スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事