NumPyで条件に応じた処理を行うnp.whereの使い方
numpy.where()
を使うと、NumPy配列ndarray
の条件を満たす要素・満たさない要素を置換したり、特定の処理を適用したりできる。条件を満たす要素のインデックス(位置)を取得することも可能。
条件を満たす要素・行・列の抽出や削除については以下の記事を参照。
本記事のサンプルコードのNumPyのバージョンは以下の通り。バージョンによって仕様が異なる可能性があるので注意。
import numpy as np
print(np.__version__)
# 1.26.1
numpy.where()の基本的な使い方
numpy.where(condition, x, y)
は、第一引数condition
の真の要素を第二引数x
、偽の要素を第三引数y
に置き換えた配列ndarray
を返す。
a = np.arange(9).reshape(3, 3)
print(a)
# [[0 1 2]
# [3 4 5]
# [6 7 8]]
print(a < 4)
# [[ True True True]
# [ True False False]
# [False False False]]
print(np.where(a < 4, -1, 100))
# [[ -1 -1 -1]
# [ -1 100 100]
# [100 100 100]]
第一引数condition
には、ndarray
だけでなくPython組み込みのリストlist
などのいわゆるarray-likeオブジェクトを指定可能。リストなどを指定しても返り値は常にndarray
。
print(np.where([True, False, True], -1, 100))
# [ -1 100 -1]
ブール値(True
, False
)以外を要素とする配列やリストは、各要素の真偽が判定されて処理される。例えば、数値の場合は0
が偽、それ以外の値はすべて真と判定される。
print(np.where([-2, -1, 0, 1, 2], -1, 100))
# [ -1 -1 100 -1 -1]
第二引数x
, 第三引数y
を省略すると、条件を満たすインデックスが返される。これについては最後に説明する。
複数条件を適用
複数の条件を組み合わせたい場合は、各条件式を括弧()
で囲んで&
(AND)や|
(OR)でつなぐ。否定~
(NOT)も使用可能。
a = np.arange(9).reshape(3, 3)
print(a)
# [[0 1 2]
# [3 4 5]
# [6 7 8]]
print((a > 2) & (a < 6))
# [[False False False]
# [ True True True]
# [False False False]]
print(np.where((a > 2) & (a < 6), -1, 100))
# [[100 100 100]
# [ -1 -1 -1]
# [100 100 100]]
print((a > 2) & (a < 6) | (a == 7))
# [[False False False]
# [ True True True]
# [False True False]]
print(np.where((a > 2) & (a < 6) | (a == 7), -1, 100))
# [[100 100 100]
# [ -1 -1 -1]
# [100 -1 100]]
&
, |
ではなくand
, or
を使ったり、括弧を省略したりするとエラーになるので注意。
条件を満たす要素を置換
条件を満たす要素、または、満たさない要素のみを置換
条件を満たす要素・満たさない要素の両方を任意の値に置換するのはこれまでの例の通り。条件を満たす要素のみ、あるいは満たさない要素のみを置換することもできる。
np.where()
の第二引数x
, 第三引数y
にはスカラー値ではなく配列ndarray
も指定可能。元のndarray
を渡せば、元の値がそのまま使われる。
a = np.arange(9).reshape(3, 3)
print(a)
# [[0 1 2]
# [3 4 5]
# [6 7 8]]
print(np.where(a < 4, -1, a))
# [[-1 -1 -1]
# [-1 4 5]
# [ 6 7 8]]
print(np.where(a < 4, a, 100))
# [[ 0 1 2]
# [ 3 100 100]
# [100 100 100]]
なお、np.where()
は新たなndarray
を返し、元のndarray
は変更されない。
a_org = np.arange(9).reshape(3, 3)
print(a_org)
# [[0 1 2]
# [3 4 5]
# [6 7 8]]
a_new = np.where(a_org < 4, -1, a_org)
print(a_new)
# [[-1 -1 -1]
# [-1 4 5]
# [ 6 7 8]]
print(a_org)
# [[0 1 2]
# [3 4 5]
# [6 7 8]]
元のndarray
自体を更新したい場合は以下のような書き方ができる。
a_org[a_org < 4] = -1
print(a_org)
# [[-1 -1 -1]
# [-1 4 5]
# [ 6 7 8]]
条件を満たす要素に処理を適用
元のndarray
の値そのままではなく、処理を適用した結果に置き換えることもできる。
a = np.arange(9).reshape(3, 3)
print(a)
# [[0 1 2]
# [3 4 5]
# [6 7 8]]
print(np.where(a < 4, a * 10, a + 100))
# [[ 0 10 20]
# [ 30 104 105]
# [106 107 108]]
ブロードキャストの影響
np.where()
の第二引数x
, 第三引数y
に指定したndarray
は、第一引数condition
に指定したndarray
と同じ形状にブロードキャストされる。
- 関連記事: NumPyのブロードキャスト(形状の自動変換)
例えば、sum()
やmean()
で列・行ごとの合計や平均を算出して列・行ごとに異なる値に置換したい場合は、keepdims=True
とすると正しくブロードキャストされて置換される。
a = np.arange(12).reshape(3, 4)
print(a)
# [[ 0 1 2 3]
# [ 4 5 6 7]
# [ 8 9 10 11]]
print(a.sum(axis=0, keepdims=True))
# [[12 15 18 21]]
print(np.where(a < 6, a.sum(axis=0, keepdims=True), 0))
# [[12 15 18 21]
# [12 15 0 0]
# [ 0 0 0 0]]
print(a.sum(axis=1, keepdims=True))
# [[ 6]
# [22]
# [38]]
print(np.where(a < 6, a.sum(axis=1, keepdims=True), 0))
# [[ 6 6 6 6]
# [22 22 0 0]
# [ 0 0 0 0]]
axis=0
ではkeepdims=False
(デフォルト)でも問題ないが、axis
に関わらずkeepdims=True
としておいたほうが間違いは少ない。
条件を満たす要素のインデックス(位置)を取得
np.where()
の第二引数x
, 第三引数y
を省略すると、条件を満たす要素のインデックス(位置)が返される。
各次元(行、列など)の条件を満たすインデックスを表すndarray
のタプルとなる。
a = np.arange(9).reshape(3, 3)
print(a)
# [[0 1 2]
# [3 4 5]
# [6 7 8]]
print(np.where(a < 4))
# (array([0, 0, 0, 1]), array([0, 1, 2, 0]))
print(type(np.where(a < 4)))
# <class 'tuple'>
print(type(np.where(a < 4)[0]))
# <class 'numpy.ndarray'>
この場合、(0, 0)
, (0, 1)
, (0, 2)
, (1, 0)
の座標の要素が条件を満たすという意味。
np.nonzero()
でも同様の処理が可能。
print(np.nonzero(a < 4))
# (array([0, 0, 0, 1]), array([0, 1, 2, 0]))
list()
, zip()
および*
による要素の展開を組み合わせて各座標のリストを取得できる。
print(list(zip(*np.where(a < 4))))
# [(0, 0), (0, 1), (0, 2), (1, 0)]
np.argwhere()
を使うと、同様の結果をndarray
で取得できる。
print(np.argwhere(a < 4))
# [[0 0]
# [0 1]
# [0 2]
# [1 0]]
三次元配列の例は以下の通り。
a_3d = np.arange(24).reshape(2, 3, 4)
print(a_3d)
# [[[ 0 1 2 3]
# [ 4 5 6 7]
# [ 8 9 10 11]]
#
# [[12 13 14 15]
# [16 17 18 19]
# [20 21 22 23]]]
print(a_3d % 5 == 0)
# [[[ True False False False]
# [False True False False]
# [False False True False]]
#
# [[False False False True]
# [False False False False]
# [ True False False False]]]
print(np.where(a_3d % 5 == 0))
# (array([0, 0, 0, 1, 1]), array([0, 1, 2, 0, 2]), array([0, 1, 2, 3, 0]))
print(np.nonzero(a_3d % 5 == 0))
# (array([0, 0, 0, 1, 1]), array([0, 1, 2, 0, 2]), array([0, 1, 2, 3, 0]))
print(list(zip(*np.where(a_3d % 5 == 0))))
# [(0, 0, 0), (0, 1, 1), (0, 2, 2), (1, 0, 3), (1, 2, 0)]
print(np.argwhere(a_3d % 5 == 0))
# [[0 0 0]
# [0 1 1]
# [0 2 2]
# [1 0 3]
# [1 2 0]]
一次元配列の例は以下の通り。
a_1d = np.arange(6)
print(a_1d)
# [0 1 2 3 4 5]
print(np.where(a_1d < 3))
# (array([0, 1, 2]),)
print(np.nonzero(a_1d < 3))
# (array([0, 1, 2]),)
print(list(zip(*np.where(a_1d < 3))))
# [(0,), (1,), (2,)]
print(np.argwhere(a_1d < 3))
# [[0]
# [1]
# [2]]
(0,)
などは要素数1個のタプルを表す。