NumPyで条件に応じた処理を行うnp.whereの使い方

Modified: | Tags: Python, NumPy

numpy.where()を使うと、NumPy配列ndarrayの条件を満たす要素・満たさない要素を置換したり、特定の処理を適用したりできる。条件を満たす要素のインデックス(位置)を取得することも可能。

条件を満たす要素・行・列の抽出や削除については以下の記事を参照。

本記事のサンプルコードのNumPyのバージョンは以下の通り。バージョンによって仕様が異なる可能性があるので注意。

import numpy as np

print(np.__version__)
# 1.26.1

numpy.where()の基本的な使い方

numpy.where(condition, x, y)は、第一引数conditionの真の要素を第二引数x、偽の要素を第三引数yに置き換えた配列ndarrayを返す。

a = np.arange(9).reshape(3, 3)
print(a)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

print(a < 4)
# [[ True  True  True]
#  [ True False False]
#  [False False False]]

print(np.where(a < 4, -1, 100))
# [[ -1  -1  -1]
#  [ -1 100 100]
#  [100 100 100]]

第一引数conditionには、ndarrayだけでなくPython組み込みのリストlistなどのいわゆるarray-likeオブジェクトを指定可能。リストなどを指定しても返り値は常にndarray

print(np.where([True, False, True], -1, 100))
# [ -1 100  -1]

ブール値(True, False)以外を要素とする配列やリストは、各要素の真偽が判定されて処理される。例えば、数値の場合は0が偽、それ以外の値はすべて真と判定される。

print(np.where([-2, -1, 0, 1, 2], -1, 100))
# [ -1  -1 100  -1  -1]

第二引数x, 第三引数yを省略すると、条件を満たすインデックスが返される。これについては最後に説明する。

複数条件を適用

複数の条件を組み合わせたい場合は、各条件式を括弧()で囲んで&(AND)や|(OR)でつなぐ。否定~(NOT)も使用可能。

a = np.arange(9).reshape(3, 3)
print(a)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

print((a > 2) & (a < 6))
# [[False False False]
#  [ True  True  True]
#  [False False False]]

print(np.where((a > 2) & (a < 6), -1, 100))
# [[100 100 100]
#  [ -1  -1  -1]
#  [100 100 100]]

print((a > 2) & (a < 6) | (a == 7))
# [[False False False]
#  [ True  True  True]
#  [False  True False]]

print(np.where((a > 2) & (a < 6) | (a == 7), -1, 100))
# [[100 100 100]
#  [ -1  -1  -1]
#  [100  -1 100]]

&, |ではなくand, orを使ったり、括弧を省略したりするとエラーになるので注意。

条件を満たす要素を置換

条件を満たす要素、または、満たさない要素のみを置換

条件を満たす要素・満たさない要素の両方を任意の値に置換するのはこれまでの例の通り。条件を満たす要素のみ、あるいは満たさない要素のみを置換することもできる。

np.where()の第二引数x, 第三引数yにはスカラー値ではなく配列ndarrayも指定可能。元のndarrayを渡せば、元の値がそのまま使われる。

a = np.arange(9).reshape(3, 3)
print(a)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

print(np.where(a < 4, -1, a))
# [[-1 -1 -1]
#  [-1  4  5]
#  [ 6  7  8]]

print(np.where(a < 4, a, 100))
# [[  0   1   2]
#  [  3 100 100]
#  [100 100 100]]

なお、np.where()は新たなndarrayを返し、元のndarrayは変更されない。

a_org = np.arange(9).reshape(3, 3)
print(a_org)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

a_new = np.where(a_org < 4, -1, a_org)
print(a_new)
# [[-1 -1 -1]
#  [-1  4  5]
#  [ 6  7  8]]

print(a_org)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

元のndarray自体を更新したい場合は以下のような書き方ができる。

a_org[a_org < 4] = -1
print(a_org)
# [[-1 -1 -1]
#  [-1  4  5]
#  [ 6  7  8]]

条件を満たす要素に処理を適用

元のndarrayの値そのままではなく、処理を適用した結果に置き換えることもできる。

a = np.arange(9).reshape(3, 3)
print(a)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

print(np.where(a < 4, a * 10, a + 100))
# [[  0  10  20]
#  [ 30 104 105]
#  [106 107 108]]

ブロードキャストの影響

np.where()の第二引数x, 第三引数yに指定したndarrayは、第一引数conditionに指定したndarrayと同じ形状にブロードキャストされる。

例えば、sum()mean()で列・行ごとの合計や平均を算出して列・行ごとに異なる値に置換したい場合は、keepdims=Trueとすると正しくブロードキャストされて置換される。

a = np.arange(12).reshape(3, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]]

print(a.sum(axis=0, keepdims=True))
# [[12 15 18 21]]

print(np.where(a < 6, a.sum(axis=0, keepdims=True), 0))
# [[12 15 18 21]
#  [12 15  0  0]
#  [ 0  0  0  0]]

print(a.sum(axis=1, keepdims=True))
# [[ 6]
#  [22]
#  [38]]

print(np.where(a < 6, a.sum(axis=1, keepdims=True), 0))
# [[ 6  6  6  6]
#  [22 22  0  0]
#  [ 0  0  0  0]]

axis=0ではkeepdims=False(デフォルト)でも問題ないが、axisに関わらずkeepdims=Trueとしておいたほうが間違いは少ない。

条件を満たす要素のインデックス(位置)を取得

np.where()の第二引数x, 第三引数yを省略すると、条件を満たす要素のインデックス(位置)が返される。

各次元(行、列など)の条件を満たすインデックスを表すndarrayのタプルとなる。

a = np.arange(9).reshape(3, 3)
print(a)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

print(np.where(a < 4))
# (array([0, 0, 0, 1]), array([0, 1, 2, 0]))

print(type(np.where(a < 4)))
# <class 'tuple'>

print(type(np.where(a < 4)[0]))
# <class 'numpy.ndarray'>

この場合、(0, 0), (0, 1), (0, 2), (1, 0)の座標の要素が条件を満たすという意味。

np.nonzero()でも同様の処理が可能。

print(np.nonzero(a < 4))
# (array([0, 0, 0, 1]), array([0, 1, 2, 0]))

list(), zip()および*による要素の展開を組み合わせて各座標のリストを取得できる。

print(list(zip(*np.where(a < 4))))
# [(0, 0), (0, 1), (0, 2), (1, 0)]

np.argwhere()を使うと、同様の結果をndarrayで取得できる。

print(np.argwhere(a < 4))
# [[0 0]
#  [0 1]
#  [0 2]
#  [1 0]]

三次元配列の例は以下の通り。

a_3d = np.arange(24).reshape(2, 3, 4)
print(a_3d)
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
# 
#  [[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]

print(a_3d % 5 == 0)
# [[[ True False False False]
#   [False  True False False]
#   [False False  True False]]
# 
#  [[False False False  True]
#   [False False False False]
#   [ True False False False]]]

print(np.where(a_3d % 5 == 0))
# (array([0, 0, 0, 1, 1]), array([0, 1, 2, 0, 2]), array([0, 1, 2, 3, 0]))

print(np.nonzero(a_3d % 5 == 0))
# (array([0, 0, 0, 1, 1]), array([0, 1, 2, 0, 2]), array([0, 1, 2, 3, 0]))

print(list(zip(*np.where(a_3d % 5 == 0))))
# [(0, 0, 0), (0, 1, 1), (0, 2, 2), (1, 0, 3), (1, 2, 0)]

print(np.argwhere(a_3d % 5 == 0))
# [[0 0 0]
#  [0 1 1]
#  [0 2 2]
#  [1 0 3]
#  [1 2 0]]

一次元配列の例は以下の通り。

a_1d = np.arange(6)
print(a_1d)
# [0 1 2 3 4 5]

print(np.where(a_1d < 3))
# (array([0, 1, 2]),)

print(np.nonzero(a_1d < 3))
# (array([0, 1, 2]),)

print(list(zip(*np.where(a_1d < 3))))
# [(0,), (1,), (2,)]

print(np.argwhere(a_1d < 3))
# [[0]
#  [1]
#  [2]]

(0,)などは要素数1個のタプルを表す。

関連カテゴリー

関連記事