NumPy配列ndarrayを分割(split, array_split, hsplit, vsplit, dsplit)

Modified: | Tags: Python, NumPy

NumPy配列ndarrayを分割するには以下の関数を使う。

  • numpy.split(): 等分割、または、任意の位置で分割
  • numpy.array_split(): できるだけ等分割で分割
  • numpy.vsplit(): 縦に分割
  • numpy.hsplit(): 横に分割
  • numpy.dsplit(): 深さ方向に分割

numpy.split()が基本で、他は特定の目的のために簡単に使えるようにしたもの。numpy.split()の使い方を理解しておけば他も理解しやすい。

「縦に分割」「横に分割」という表現は曖昧だが、ここではNumPyの公式ドキュメントに則り、上下に分割することを縦に分割、左右に分割することを横に分割と表す。

最後に説明するように、分割された配列は元の配列のビュー。一方の要素を変更すると他方の要素も変更されるので注意。

複数の配列ndarrayの結合(連結)については以下の記事を参照。

本記事のサンプルコードのNumPyのバージョンは以下の通り。バージョンによって仕様が異なる可能性があるので注意。

import numpy as np

print(np.__version__)
# 1.26.1

numpy.split()の基本的な使い方

配列ndarrayを分割する基本的な関数がnumpy.split()

戻り値は配列ndarrayのリスト

numpy.split()の第一引数に分割する配列ndarray、第二引数・第三引数に分割方法を指定する。

縦に二等分で分割する例は以下の通り。第二引数を2とし、第三引数を省略する(詳細は後述)。

戻り値は配列ndarrayのリストとなる。元の配列は変更されない。

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

a_split = np.split(a, 2)

print(type(a_split))
# <class 'list'>

print(len(a_split))
# 2

print(a_split[0])
# [[0 1 2 3]
#  [4 5 6 7]]

print(a_split[1])
# [[ 8  9 10 11]
#  [12 13 14 15]]

print(type(a_split[0]))
# <class 'numpy.ndarray'>

print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

アンパックを利用して別々の変数に格納することも可能。

a0, a1 = np.split(a, 2)

print(a0)
# [[0 1 2 3]
#  [4 5 6 7]]

print(a1)
# [[ 8  9 10 11]
#  [12 13 14 15]]

分割個数または分割位置を指定: 引数indices_or_sections

分割個数を整数で指定

第二引数indices_or_sectionsに整数intを指定するとその個数で等分割される。割り切れない値を指定するとエラーとなる。

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

a0, a1 = np.split(a, 2)

print(a0)
# [[0 1 2 3]
#  [4 5 6 7]]

print(a1)
# [[ 8  9 10 11]
#  [12 13 14 15]]

# np.split(a, 3)
# ValueError: array split does not result in an equal division

後述のnumpy.array_split()を使うと、割り切れない場合に適当に行数や列数を調整してくれる。

分割位置をリスト(配列)で指定

第二引数indices_or_sectionsに整数のリスト(配列)を指定すると、そのインデックス(位置)で分割される。インデックスは0始まりで指定。

例えば[1, 3]とすると、1行目の前(0行目と1行目の間)と3行目の前(2行目と3行目の間)で分割される。

a0, a1, a2 = np.split(a, [1, 3])

print(a0)
# [[0 1 2 3]]

print(a1)
# [[ 4  5  6  7]
#  [ 8  9 10 11]]

print(a2)
# [[12 13 14 15]]

任意のインデックスで二分割したい場合は要素数1のリストを指定する。

a0, a1 = np.split(a, [1])

print(a0)
# [[0 1 2 3]]

print(a1)
# [[ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

範囲外のインデックスを指定すると、空のndarrayが返される。

a0, a1 = np.split(a, [10])

print(a0)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

print(a1)
# []

print(type(a1))
# <class 'numpy.ndarray'>

分割する軸(次元)を指定: 引数axis

分割する軸(次元)は第三引数axisで指定する。

これまでの例のように第三引数を省略するとデフォルトのaxis=0となる。もちろんaxis=0と明示的に指定してもよい。この場合、0次元目の軸に沿って、すなわち、行ごとに分割される。

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

a0, a1 = np.split(a, 2, 0)

print(a0)
# [[0 1 2 3]
#  [4 5 6 7]]

print(a1)
# [[ 8  9 10 11]
#  [12 13 14 15]]

axis=1とすると、1次元目の軸に沿って、つまり、列ごとに分割される。

a0, a1 = np.split(a, 2, 1)

print(a0)
# [[ 0  1]
#  [ 4  5]
#  [ 8  9]
#  [12 13]]

print(a1)
# [[ 2  3]
#  [ 6  7]
#  [10 11]
#  [14 15]]

存在しない次元を指定するとエラーとなる。

# np.split(a, 2, 2)
# IndexError: tuple index out of range

三次元以上の多次元配列の場合の例

ここまでは、便宜上、行・列といった文言で説明したが、三次元以上の場合も考え方は同じ。以下の配列を例とする。

a_3d = np.arange(24).reshape(2, 3, 4)
print(a_3d)
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
# 
#  [[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]

print(a_3d.shape)
# (2, 3, 4)

第三引数axisに分割する対象の軸(次元)を指定し、第二引数indices_or_sectionsに整数またはリストで分割個数や位置を指定する。

a0, a1 = np.split(a_3d, 2, 0)

print(a0)
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]]

print(a1)
# [[[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]

a0, a1 = np.split(a_3d, [1], 2)

print(a0)
# [[[ 0]
#   [ 4]
#   [ 8]]
# 
#  [[12]
#   [16]
#   [20]]]

print(a1)
# [[[ 1  2  3]
#   [ 5  6  7]
#   [ 9 10 11]]
# 
#  [[13 14 15]
#   [17 18 19]
#   [21 22 23]]]

numpy.array_split()でできるだけ等分割で分割

numpy.split()で第二引数indices_or_sectionsに整数で分割個数を指定する場合は割り切れないとエラーになるが、numpy.array_split()だと割り切れないときに適当に行数や列数を調整してくれる。それ以外はnumpy.split()と同じ動作。引数の設定なども同じ。

以下の配列を例とする。

a = np.arange(15).reshape(3, 5)
print(a)
# [[ 0  1  2  3  4]
#  [ 5  6  7  8  9]
#  [10 11 12 13 14]]

numpy.split()だと割り切れない(等分割できない)場合にエラーとなる。

# np.split(a, 2, 0)
# ValueError: array split does not result in an equal division

numpy.array_split()だとエラーにならずに適当に調整してくれる。この例だと先頭の配列の行数を一つ増やして分割される。

a0, a1 = np.array_split(a, 2, 0)

print(a0)
# [[0 1 2 3 4]
#  [5 6 7 8 9]]

print(a1)
# [[10 11 12 13 14]]

行数・列数は以下のルールに従って調整される。

For an array of length l that should be split into n sections, it returns l % n sub-arrays of size l//n + 1 and the rest of size l//n.
numpy.array_split — NumPy v1.26 Manual

余り分を分割後の配列の先頭から振り分けていくイメージ。以下の例では先頭二つの配列の列数を一つ増やして分割される。

a0, a1, a2 = np.array_split(a, 3, 1)

print(a0)
# [[ 0  1]
#  [ 5  6]
#  [10 11]]

print(a1)
# [[ 2  3]
#  [ 7  8]
#  [12 13]]

print(a2)
# [[ 4]
#  [ 9]
#  [14]]

numpy.vsplit()で縦に分割

numpy.vsplit()は配列を縦(vertical)に分割する。axis=0numpy.split()と等価。

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

a0, a1 = np.vsplit(a, 2)

print(a0)
# [[0 1 2 3]
#  [4 5 6 7]]

print(a1)
# [[ 8  9 10 11]
#  [12 13 14 15]]

numpy.split()と同じく、第二引数indices_or_sectionsに整数で分割個数を指定する場合、割り切れないとエラーになる。

第二引数indices_or_sectionsにリスト(配列)を指定して位置で分割することも可能。

a0, a1 = np.split(a, [1])

print(a0)
# [[0 1 2 3]]

print(a1)
# [[ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

numpy.hsplit()で横に分割

numpy.hsplit()は配列を横(horizontal)に分割する。axis=1numpy.split()と等価。

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

a0, a1 = np.hsplit(a, 2)

print(a0)
# [[ 0  1]
#  [ 4  5]
#  [ 8  9]
#  [12 13]]

print(a1)
# [[ 2  3]
#  [ 6  7]
#  [10 11]
#  [14 15]]

numpy.split()と同じく、第二引数indices_or_sectionsに整数で分割個数を指定する場合、割り切れないとエラーになる。

第二引数indices_or_sectionsにリスト(配列)を指定して位置で分割することも可能。

a0, a1 = np.hsplit(a, [1])

print(a0)
# [[ 0]
#  [ 4]
#  [ 8]
#  [12]]

print(a1)
# [[ 1  2  3]
#  [ 5  6  7]
#  [ 9 10 11]
#  [13 14 15]]

axis=1numpy.split()と異なり、numpy.hsplit()は一次元配列に対してもエラーにならない。

a_1d = np.arange(6)
print(a_1d)
# [0 1 2 3 4 5]

# np.split(a_1d, 2, 1)
# IndexError: tuple index out of range

a0, a1 = np.hsplit(a_1d, 2)

print(a0)
# [0 1 2]

print(a1)
# [3 4 5]

これはnumpy.hsplit()は一次元配列に対してはaxis=0としてnumpy.split()を呼んでいるから。以下のソースを参照。

numpy.dsplit()で深さ方向に分割

numpy.dsplit()は配列を深さ(depth)方向に分割する関数。axis=2numpy.split()と等価。

print()の出力だとイメージしにくいかもしれないが、形状shapeを見ると二次元目で分割されていることが確認できる。

a_3d = np.arange(24).reshape(2, 3, 4)
print(a_3d)
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
# 
#  [[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]

print(a_3d.shape)
# (2, 3, 4)

a0, a1 = np.dsplit(a_3d, 2)

print(a0)
# [[[ 0  1]
#   [ 4  5]
#   [ 8  9]]
# 
#  [[12 13]
#   [16 17]
#   [20 21]]]

print(a0.shape)
# (2, 3, 2)

print(a1)
# [[[ 2  3]
#   [ 6  7]
#   [10 11]]
# 
#  [[14 15]
#   [18 19]
#   [22 23]]]

print(a1.shape)
# (2, 3, 2)

numpy.split()と同じく、第二引数indices_or_sectionsに整数で分割個数を指定する場合、割り切れないとエラーになる。

第二引数indices_or_sectionsにリスト(配列)を指定して位置で分割することも可能。

a0, a1 = np.dsplit(a_3d, [1])

print(a0)
# [[[ 0]
#   [ 4]
#   [ 8]]
# 
#  [[12]
#   [16]
#   [20]]]

print(a1)
# [[[ 1  2  3]
#   [ 5  6  7]
#   [ 9 10 11]]
# 
#  [[13 14 15]
#   [17 18 19]
#   [21 22 23]]]

numpy.dsplit()は三次元以上の配列が対象。二次元以下の配列に対してはエラーとなる。

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

# np.dsplit(a, 2)
# ValueError: dsplit only works on arrays of 3 or more dimensions

分割された配列は元の配列のビュー

これまで説明したnumpy.split()などの関数は、分割された配列ndarrayのリストを返すが、それらの配列は元の配列のビュー。

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

l = np.split(a, 2)

print(l[0])
# [[0 1 2 3]
#  [4 5 6 7]]

print(np.shares_memory(a, l[0]))
# True

一方の要素を変更すると、他方の要素も変更される。

a[0, 0] = 100
print(a)
# [[100   1   2   3]
#  [  4   5   6   7]
#  [  8   9  10  11]
#  [ 12  13  14  15]]

print(l[0])
# [[100   1   2   3]
#  [  4   5   6   7]]

他の関数でも同様。

print(np.shares_memory(a, np.vsplit(a, 2)[0]))
# True

print(np.shares_memory(a, np.hsplit(a, 2)[0]))
# True

print(np.shares_memory(a, np.array_split(a, 3)[0]))
# True

a_3d = np.arange(24).reshape(2, 3, 4)
print(np.shares_memory(a_3d, np.dsplit(a_3d, 2)[0]))
# True

別々に処理したい場合は、copy()で元の配列のコピーを生成して各関数に渡せばよい。

a = np.arange(16).reshape(4, 4)

l_copy = np.split(a.copy(), 2)

print(np.shares_memory(a, l_copy[0]))
# False

a[0, 0] = 100
print(a)
# [[100   1   2   3]
#  [  4   5   6   7]
#  [  8   9  10  11]
#  [ 12  13  14  15]]

print(l_copy[0])
# [[0 1 2 3]
#  [4 5 6 7]]

NumPyにおけるビューとコピーについての詳細は以下の記事を参照。

関連カテゴリー

関連記事