NumPy配列ndarrayの最大値・最小値のインデックス(位置)を取得

Modified: | Tags: Python, NumPy

NumPy配列ndarrayで最大値・最小値となる要素のインデックス(位置)を取得するには関数np.argmax(), np.argmin()を使う。それぞれndarrayのメソッドとしても用意されている。

また、np.where()を利用する方法もある。

ここでは以下の内容について説明する。

  • np.argmax()の基本的な使い方
  • 二次元配列(多次元配列)の場合
    • 平坦化された状態のインデックス
    • 各列・各行の最大値のインデックス
    • 全体の最大値のインデックス(座標)
  • ndarrayargmax()メソッド
  • np.argmin(), ndarrayargmin()メソッド
  • np.where()を利用

Python組み込みのリストlistの最大値・最小値のインデックスを取得したい場合は、max(), min()index()を使う。詳細は以下の記事を参照。

np.argmax()の基本的な使い方

以下の一次元配列を例とする。

import numpy as np

a = np.array([1, 100, 10])
print(a)
# [  1 100  10]

np.argmax()の引数にndarrayを指定すると最大値となる要素のインデックスが返される。インデックスは0始まり。

print(np.argmax(a))
# 1

最大値となる要素が複数ある場合は最初のインデックスのみが返される。

a = np.array([1, 10, 10])
print(a)
# [ 1 10 10]

print(np.argmax(a))
# 1

二次元配列(多次元配列)の場合

多次元配列の場合として、二次元配列を例とする。

a_2d = np.array([[20, 50, 30], [60, 40, 10]])
print(a_2d)
# [[20 50 30]
#  [60 40 10]]

平坦化された状態のインデックス

デフォルトでは、平坦化(一次元化)された状態でのインデックスが返される。

print(np.argmax(a_2d))
# 3

ndarrayflatten()メソッドで平坦化できる。デフォルトでは平坦化(一次元化)された状態でのインデックスが返されていることが確認できる。

print(a_2d.flatten())
# [20 50 30 60 40 10]

print(np.argmax(a_2d.flatten()))
# 3

各列・各行の最大値のインデックス

第二引数axisを指定すると、各軸に沿って最大値となるインデックスが返される。デフォルトはaxis=Noneで、上述のように平坦化された状態でのインデックスが返される。

例えばaxis=0とすると各列ごとの最大値の行番号が返される。各列の最大値そのものはnp.max()axis=0とすると得られる。元の配列と見比べて位置を確認されたい。

print(np.argmax(a_2d, axis=0))
# [1 0 0]

print(np.max(a_2d, axis=0))
# [60 50 30]

同様にaxis=1とすると行ごとの最大値の列番号が返される。

print(np.argmax(a_2d, axis=1))
# [1 0]

print(np.max(a_2d, axis=1))
# [50 60]

全体の最大値のインデックス(座標)

全体の最大値のインデックスを取得するにはnp.unravel_index()を使う。

第一引数indicesに平坦化(一次元化)した状態でのインデックス、第二引数shapeに元のndarrayの形状shapeを指定すると、元の形状におけるインデックスがタプルで返される。

上述のようにnp.argmax()のデフォルトでは最大値の平坦化されたインデックスが返されるので、これをnp.unravel_index()の第一引数indicesに指定すればよい。

idx = np.unravel_index(np.argmax(a_2d), a_2d.shape)
print(idx)
# (1, 0)

結果の(1, 0)は1行0列の要素が最大値であることを示している。

得られたインデックスを[]に指定すると最大値が取得できていることが確認できる。

print(a_2d[idx])
# 60

print(np.max(a_2d))
# 60

なお、この方法だと最大値が複数ある場合に最初の要素のインデックスしか取得できない。np.where()を利用すると複数のインデックスを取得可能。最後に説明する。

ndarrayのargmax()メソッド

ここまでは関数np.argmax()の使い方を説明したが、配列ndarrayargmax()メソッドとして使う場合も同様。なお、np.max()にもndarrayに対応するメソッドmax()がある。

a = np.array([1, 100, 10])
print(a)
# [  1 100  10]

print(a.argmax())
# 1

print(a.max())
# 100

多次元配列の場合は目的によって引数axisを指定する。

a_2d = np.array([[20, 50, 30], [60, 40, 10]])
print(a_2d)
# [[20 50 30]
#  [60 40 10]]

print(a_2d.argmax())
# 3

print(a_2d.argmax(axis=0))
# [1 0 0]

print(a_2d.max(axis=0))
# [60 50 30]

print(a_2d.argmax(axis=1))
# [1 0]

print(a_2d.max(axis=1))
# [50 60]

np.unravel_index()と組み合わせて使う場合も同じ。

print(np.unravel_index(a_2d.argmax(), a_2d.shape))
# (1, 0)

np.argmin(), ndarrayのargmin()メソッド

argmin()の場合も考え方はargmax()と同様。関数np.argmin()、配列ndarrayargmin()メソッドがある。

一次元配列の場合。

import numpy as np

a = np.array([1, 100, 10])
print(a)
# [  1 100  10]

print(np.argmin(a))
# 0

print(a.argmin())
# 0

二次元配列の場合。

デフォルトは平坦化された状態でのインデックスが返される。

a_2d = np.array([[20, 50, 30], [60, 40, 10]])
print(a_2d)
# [[20 50 30]
#  [60 40 10]]

print(np.argmin(a_2d))
# 5

print(a_2d.argmin())
# 5

引数axisで行ごと・列ごとのインデックスを取得可能。

print(np.argmin(a_2d, axis=0))
# [0 1 1]

print(a_2d.argmin(axis=0))
# [0 1 1]

print(np.min(a_2d, axis=0))
# [20 40 10]

print(a_2d.min(axis=0))
# [20 40 10]

print(np.argmin(a_2d, axis=1))
# [0 2]

print(a_2d.argmin(axis=1))
# [0 2]

print(np.min(a_2d, axis=1))
# [20 10]

print(a_2d.min(axis=1))
# [20 10]

np.unravel_index()を使う場合。

idx = np.unravel_index(np.argmin(a_2d), a_2d.shape)
print(idx)
# (1, 2)

print(a_2d[idx])
# 10

print(np.min(a_2d))
# 10

print(np.unravel_index(a_2d.argmin(), a_2d.shape))
# (1, 2)

np.where()を利用

上述のように、np.argmax()np.argmin()では、最大値や最小値が複数ある場合に最初の要素のインデックスしか取得できない。

a_2d = np.array([[10, 20, 30], [30, 20, 10]])
print(a_2d)
# [[10 20 30]
#  [30 20 10]]

print(np.max(a_2d))
# 30

print(np.min(a_2d))
# 10

print(np.unravel_index(np.argmax(a_2d), a_2d.shape))
# (0, 2)

print(np.unravel_index(np.argmin(a_2d), a_2d.shape))
# (0, 0)

np.where()を利用すると複数の最大値や最小値のインデックスのリストを取得できる。np.max()np.min()で最大値・最小値を取得し、その値と等しい要素のインデックスを抽出する。

print(np.where(a_2d == np.max(a_2d)))
# (array([0, 1]), array([2, 0]))

print(list(zip(*np.where(a_2d == np.max(a_2d)))))
# [(0, 2), (1, 0)]

print(np.where(a_2d == np.min(a_2d)))
# (array([0, 1]), array([0, 2]))

print(list(zip(*np.where(a_2d == np.min(a_2d)))))
# [(0, 0), (1, 2)]

np.where()list(), zip()による転置を組み合わせている。詳細は以下の記事を参照。

関連カテゴリー

関連記事