note.nkmk.me

numpy.where(): Process elements depending on conditions

Posted: 2019-05-29 / Modified: 2019-11-05 / Tags: Python, NumPy

With numpy.where(), you can replace or otherwise process the elements of a NumPy array (ndarray) that satisfy a condition.

This post covers the following topics.

  • Overview of np.where()
  • Multiple conditions
  • Replace the elements that satisfy the condition
  • Process the elements that satisfy the condition
  • Get the indices of the elements that satisfy the condition

If you want to extract or delete elements, rows and columns that satisfy the conditions, see the following post.

Overview of np.where()

numpy.where(condition[, x, y])
Return elements, either from x or y, depending on condition.
If only condition is given, return condition.nonzero().
numpy.where — NumPy v1.14 Manual

np.where() returns an ndarray whose elements are taken from x where condition is True and from y where it is False. x, y and condition must be broadcastable to the same shape.

If x and y are omitted, the indices of the elements that satisfy the condition are returned. This is described later.

import numpy as np

a = np.arange(9).reshape((3, 3))
print(a)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

print(np.where(a < 4, -1, 100))
# [[ -1  -1  -1]
#  [ -1 100 100]
#  [100 100 100]]
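
As a small illustration of the broadcasting requirement mentioned above, the following sketch passes a one-dimensional array as x and a scalar as y. The name row_values is made up for this example and is not part of NumPy.

```python
import numpy as np

a = np.arange(9).reshape((3, 3))

# row_values (hypothetical name) has shape (3,), so it is broadcast
# across every row of the (3, 3) condition; y is a plain scalar.
row_values = np.array([10, 20, 30])
result = np.where(a < 4, row_values, -1)
print(result)
# [[10 20 30]
#  [10 -1 -1]
#  [-1 -1 -1]]
```

Each True position takes the value from the matching column of row_values, and each False position takes -1.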

A boolean ndarray can be obtained from a conditional expression on an ndarray alone, without using np.where().

print(a < 4)
# [[ True  True  True]
#  [ True False False]
#  [False False False]]

Multiple conditions

To apply multiple conditions, enclose each conditional expression in () and combine them with & (and) or | (or).

print(np.where((a > 2) & (a < 6), -1, 100))
# [[100 100 100]
#  [ -1  -1  -1]
#  [100 100 100]]

print(np.where((a > 2) & (a < 6) | (a == 7), -1, 100))
# [[100 100 100]
#  [ -1  -1  -1]
#  [100  -1 100]]

Even with multiple conditions, np.where() is not needed to obtain the boolean ndarray.

print((a > 2) & (a < 6))
# [[False False False]
#  [ True  True  True]
#  [False False False]]

print((a > 2) & (a < 6) | (a == 7))
# [[False False False]
#  [ True  True  True]
#  [False  True False]]
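
The sketch below shows why the parentheses matter and how the same mask could be written with np.logical_and() and np.logical_or(). It also negates the mask with ~. The variable names mask_ops and mask_funcs are chosen only for this example.

```python
import numpy as np

a = np.arange(9).reshape((3, 3))

# & binds more tightly than the comparison operators, so omitting the
# parentheses produces a chained comparison that raises ValueError.
try:
    a > 2 & a < 6
    raised = False
except ValueError:
    raised = True
print(raised)
# True

# The same mask written with operators and with named functions.
mask_ops = (a > 2) & (a < 6) | (a == 7)
mask_funcs = np.logical_or(np.logical_and(a > 2, a < 6), a == 7)
print(np.array_equal(mask_ops, mask_funcs))
# True

# ~ inverts a boolean mask element by element.
print(np.where(~mask_ops, -1, 100))
# [[ -1  -1  -1]
#  [100 100 100]
#  [ -1 100  -1]]
```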

Replace the elements that satisfy the condition

You can also replace elements with an arbitrary value only where the condition is satisfied, or only where it is not.

If you pass the original ndarray as x or y, the original values are kept as they are for those elements.

print(np.where(a < 4, -1, a))
# [[-1 -1 -1]
#  [-1  4  5]
#  [ 6  7  8]]

print(np.where(a < 4, a, 100))
# [[  0   1   2]
#  [  3 100 100]
#  [100 100 100]]

Note that np.where() returns a new ndarray, and the original ndarray is unchanged.

a_org = np.arange(9).reshape((3, 3))
print(a_org)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

a_new = np.where(a_org < 4, -1, a_org)
print(a_new)
# [[-1 -1 -1]
#  [-1  4  5]
#  [ 6  7  8]]

print(a_org)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

If you want to update the original ndarray itself, you can write:

a_org[a_org < 4] = -1
print(a_org)
# [[-1 -1 -1]
#  [-1  4  5]
#  [ 6  7  8]]

Process the elements that satisfy the condition

Instead of the original ndarray, you can also specify the result of an operation (calculation) as x or y.

print(np.where(a < 4, a * 10, a))
# [[ 0 10 20]
#  [30  4  5]
#  [ 6  7  8]]
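
One point worth keeping in mind when passing a calculation as x or y: the expression is evaluated for the whole array before np.where() selects from it, so it also runs on elements that fail the condition. The following sketch, using an example array b that does not appear elsewhere in this post, silences the resulting division warning with np.errstate().

```python
import numpy as np

b = np.array([0.0, 1.0, 2.0, 4.0])

# 1 / b is computed for every element, including b == 0, so the
# division-by-zero warning is explicitly ignored in this block.
with np.errstate(divide='ignore'):
    inv = np.where(b != 0, 1 / b, 0.0)

print(inv.tolist())
# [0.0, 1.0, 0.5, 0.25]
```

The value computed for b == 0 is discarded by np.where(), but the calculation itself still takes place.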

Get the indices of the elements that satisfy the condition

If x and y are omitted, the indices of the elements that satisfy the condition are returned.

The result is a tuple of index arrays, one for each dimension (row, column). Together they give the (row number, column number) of each element that satisfies the condition.

print(np.where(a < 4))
# (array([0, 0, 0, 1]), array([0, 1, 2, 0]))

print(type(np.where(a < 4)))
# <class 'tuple'>

In this case, it means that the elements at [0, 0], [0, 1], [0, 2] and [1, 0] satisfy the condition.

It is also possible to obtain a list of each coordinate by using list(), zip() and * as follows.

print(list(zip(*np.where(a < 4))))
# [(0, 0), (0, 1), (0, 2), (1, 0)]
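
As a brief sketch of what the returned tuple can be used for, it can be passed back to the array as an index to get the matching values. np.argwhere() is shown as one alternative that returns the coordinates as a two-dimensional ndarray. The name idx is chosen only for this example.

```python
import numpy as np

a = np.arange(9).reshape((3, 3))

# The tuple of index arrays can index the original array directly.
idx = np.where(a < 4)
print(a[idx])
# [0 1 2 3]

# np.argwhere() returns one row of coordinates per matching element.
print(np.argwhere(a < 4))
# [[0 0]
#  [0 1]
#  [0 2]
#  [1 0]]
```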

The same applies to multi-dimensional arrays of three or more dimensions.

a_3d = np.arange(24).reshape(2, 3, 4)
print(a_3d)
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
# 
#  [[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]

print(np.where(a_3d < 5))
# (array([0, 0, 0, 0, 0]), array([0, 0, 0, 0, 1]), array([0, 1, 2, 3, 0]))

print(list(zip(*np.where(a_3d < 5))))
# [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), (0, 1, 0)]

The same applies to one-dimensional arrays. Note that when you use list(), zip(), and *, each element in the resulting list is a tuple with one element.

a_1d = np.arange(6)
print(a_1d)
# [0 1 2 3 4 5]

print(np.where(a_1d < 3))
# (array([0, 1, 2]),)

print(list(zip(*np.where(a_1d < 3))))
# [(0,), (1,), (2,)]

If you know that the array is one-dimensional, you can use the first element of the result of np.where() directly. In this case, it is an ndarray of integers, not a list of one-element tuples. If you want to convert it to a list, use tolist().

print(np.where(a_1d < 3)[0])
# [0 1 2]

print(np.where(a_1d < 3)[0].tolist())
# [0, 1, 2]

The number of dimensions can be obtained with the ndim attribute.
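
For example, a minimal check reusing arrays with the same shapes as the ones above:

```python
import numpy as np

a_1d = np.arange(6)
a_3d = np.arange(24).reshape(2, 3, 4)

print(a_1d.ndim)
# 1

print(a_3d.ndim)
# 3

# The tuple returned by np.where() has one index array per dimension.
print(len(np.where(a_3d < 5)) == a_3d.ndim)
# True
```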
