note.nkmk.me

NumPyで条件に応じた処理を行うwhereの使い方

Date: 2017-02-18 / Modified: 2018-03-08 / tags: Python, NumPy
このエントリーをはてなブックマークに追加

numpy.where()を使うと、NumPy配列ndarrayに対して、条件を満たす要素を置換したり、特定の処理を行ったりすることができる。

条件を満たす要素や行、列を抽出したり削除したりしたい場合は以下の記事を参照。

スポンサーリンク

numpy.where()の概要

numpy.where(condition[, x, y])
Return elements, either from x or y, depending on condition.
If only condition is given, return condition.nonzero().
numpy.where — NumPy v1.14 Manual

numpy.whereは、条件式conditionを満たす場合(真Trueの場合)はx、満たさない場合(偽Falseの場合)はyとするndarrayを返す関数。

x, yを省略した場合は、条件を満たすindexを返す(最後に説明)。

import numpy as np

a = np.arange(9).reshape((3, 3))
print(a)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

print(np.where(a < 4, -1, 100))
# [[ -1  -1  -1]
#  [ -1 100 100]
#  [100 100 100]]

条件を満たす場合はTrue, 満たさない場合はFalseとするのはnp.where()を使わなくても、ndarrayを含む条件式で取得できる。

print(np.where(a < 4, True, False))
# [[ True  True  True]
#  [ True False False]
#  [False False False]]

print(a < 4)
# [[ True  True  True]
#  [ True False False]
#  [False False False]]

複数条件を適用

各条件式を()で囲み&|を使うと、複数条件に対して処理が適用される。

print(np.where((a > 2) & (a < 6), -1, 100))
# [[100 100 100]
#  [ -1  -1  -1]
#  [100 100 100]]

print(np.where((a > 2) & (a < 6) | (a == 7), -1, 100))
# [[100 100 100]
#  [ -1  -1  -1]
#  [100  -1 100]]

複数条件の場合も、True, Falsendarrayを取得するのはnp.where()を使わなくてもよい。

print((a > 2) & (a < 6))
# [[False False False]
#  [ True  True  True]
#  [False False False]]

print((a > 2) & (a < 6) | (a == 7))
# [[False False False]
#  [ True  True  True]
#  [False  True False]]

条件を満たす要素を置換

条件を満たす場合も満たさない場合も任意の値に置換するのはこれまでの例の通り。

条件を満たす場合のみ、あるいは満たさない場合のみ任意の値に置換することもできる。

np.where()の引数x, yに元のndarrayを渡せば、元の値がそのまま使われる。

print(np.where(a < 4, -1, a))
# [[-1 -1 -1]
#  [-1  4  5]
#  [ 6  7  8]]

print(np.where(a < 4, a, 100))
# [[  0   1   2]
#  [  3 100 100]
#  [100 100 100]]

なお、np.where()は新たなndarrayを返し、元のndarrayは変更されない。

a_org = np.arange(9).reshape((3, 3))
print(a_org)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

a_new = np.where(a_org < 4, -1, a_org)
print(a_new)
# [[-1 -1 -1]
#  [-1  4  5]
#  [ 6  7  8]]

print(a_org)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

元のndarray自体を更新したい場合は以下のような書き方ができる。

a_org[a_org < 4] = -1
print(a_org)
# [[-1 -1 -1]
#  [-1  4  5]
#  [ 6  7  8]]

条件を満たす要素を処理

元のndarrayの値そのままではなく、演算した結果を使うこともできる。

print(np.where(a < 4, a * 10, a))
# [[ 0 10 20]
#  [30  4  5]
#  [ 6  7  8]]

条件を満たす要素のインデックスを取得

引数x, yを省略した場合は、条件を満たす要素のインデックスを返す。

各次元(行、列)に対して条件を満たすインデックス(行番号、列番号)のリストのタプルとなる。

print(np.where(a < 4))
# (array([0, 0, 0, 1]), array([0, 1, 2, 0]))

print(type(np.where(a < 4)))
# <class 'tuple'>

この場合、[0, 0][0, 1][0, 2][1, 0]の要素が条件を満たすという意味。

スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事