NumPy: Meaning of the axis parameter (0, 1, -1)
In NumPy, functions such as np.sum(), np.mean(), and np.max() have an axis parameter that specifies the target of the operation: the entire array, columns, rows, or other dimensions.
The meaning of the term "axis" in NumPy is explained in the official documentation's glossary as follows:
axis
Another term for an array dimension. Axes are numbered left to right; axis 0 is the first element in the shape tuple. Glossary - axis — NumPy v1.26 Manual
This article explains the meaning and usage of the axis parameter in NumPy.
The NumPy version used in this article is as follows. Note that functionality may vary between versions.
import numpy as np
print(np.__version__)
# 1.26.1
In a 2D array, axis=0 operates column-wise, axis=1 operates row-wise
In a two-dimensional array, axis=0 operates column-wise (one result per column) and axis=1 operates row-wise (one result per row).
For example, use np.sum() to calculate the sum.
a = np.arange(12).reshape(3, 4)
print(a)
# [[ 0 1 2 3]
# [ 4 5 6 7]
# [ 8 9 10 11]]
print(np.sum(a, axis=0))
# [12 15 18 21]
print(np.sum(a, axis=1))
# [ 6 22 38]
The default is axis=None, which operates on the entire array.
print(np.sum(a))
# 66
print(np.sum(a, axis=None))
# 66
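As a minimal sketch of what axis=None means, the following compares it with flattening the array before summing. It also uses the method form a.sum(), which takes the same axis parameter. The variable names total and flat_total are chosen here for illustration only.

```python
import numpy as np

a = np.arange(12).reshape(3, 4)

# axis=None reduces over every element, so the result is a single value
total = np.sum(a, axis=None)
print(total)
# 66

# Flattening to 1D first and then summing gives the same value
flat_total = a.ravel().sum()
print(flat_total)
# 66

# The ndarray method accepts the same axis parameter as the function
print(a.sum(axis=0))
# [12 15 18 21]
```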
An error is raised if an axis outside the array's dimensions is specified.
# print(np.sum(a, axis=2))
# AxisError: axis 2 is out of bounds for array of dimension 2
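If the axis value comes from user input, the error can be handled rather than left to propagate. The sketch below relies on AxisError being a subclass of ValueError (it is also a subclass of IndexError); safe_sum is a hypothetical helper written for this example, not a NumPy function.

```python
import numpy as np

a = np.arange(12).reshape(3, 4)

def safe_sum(arr, axis):
    """Hypothetical helper: return None instead of raising for an invalid axis."""
    try:
        return np.sum(arr, axis=axis)
    except ValueError:
        # AxisError derives from ValueError, so it is caught here
        return None

print(safe_sum(a, 1))
# [ 6 22 38]
print(safe_sum(a, 2))
# None
```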
The above example uses np.sum(), but the same applies to np.mean(), np.max(), np.min(), and so on.
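To illustrate with the same 3x4 array, here is a short sketch using np.mean(), np.max(), and np.min(). Only the reduction changes; the meaning of axis stays the same.

```python
import numpy as np

a = np.arange(12).reshape(3, 4)

# axis=0: one result per column
print(np.mean(a, axis=0))
# [4. 5. 6. 7.]
print(np.min(a, axis=0))
# [0 1 2 3]

# axis=1: one result per row
print(np.mean(a, axis=1))
# [1.5 5.5 9.5]
print(np.max(a, axis=1))
# [ 3  7 11]
```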
axis=-1 represents the last axis
You can use negative values for the axis parameter to specify an axis counting backward from the last one. -1 represents the last axis.
In a two-dimensional array, axis=-1 is equivalent to axis=1, and axis=-2 is equivalent to axis=0.
a = np.arange(12).reshape(3, 4)
print(a)
# [[ 0 1 2 3]
# [ 4 5 6 7]
# [ 8 9 10 11]]
print(np.sum(a, axis=-1))
# [ 6 22 38]
print(np.sum(a, axis=-2))
# [12 15 18 21]
An error is raised if an axis outside the array's dimensions is specified.
# print(np.sum(a, axis=-3))
# AxisError: axis -3 is out of bounds for array of dimension 2
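One practical use of axis=-1 is writing code that works on the last axis regardless of how many dimensions the array has. The following is a rough sketch; sum_last_axis is a hypothetical helper named for this example.

```python
import numpy as np

def sum_last_axis(arr):
    """Hypothetical helper: sum over the last axis, whatever the number of dimensions."""
    return np.sum(arr, axis=-1)

for shape in [(4,), (3, 4), (2, 3, 4)]:
    arr = np.ones(shape, int)
    # A negative axis k refers to the same axis as k + arr.ndim
    assert np.array_equal(sum_last_axis(arr), np.sum(arr, axis=arr.ndim - 1))
    print(shape, '->', sum_last_axis(arr).shape)
# (4,) -> ()
# (3, 4) -> (3,)
# (2, 3, 4) -> (2, 3)
```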
Meaning of the axis parameter
For two-dimensional arrays, it is enough to remember that axis=0 operates column-wise and axis=1 operates row-wise, but let's consider the meaning of the axis parameter more generally.
Here, np.sum() is used as an example, but the same concept applies to other functions and methods like np.mean().
In the case of a 2D array
Consider a two-dimensional array with the shape (2, 3).
a = np.arange(6).reshape(2, 3)
print(a)
# [[0 1 2]
# [3 4 5]]
np.sum() returns a one-dimensional array with the shape (3,) for axis=0, and (2,) for axis=1.
print(np.sum(a, axis=0))
# [3 5 7]
print(np.sum(a, axis=1))
# [ 3 12]
axis=0 aggregates along the first axis, and axis=1 aggregates along the second axis. In each case, the specified axis is removed from the shape and the other axis is preserved.
input : (2, 3)
axis=0 : (_, 3) -> (3,)
axis=1 : (2, _) -> (2,)
np.sum(a, axis=0) and np.sum(a, axis=1) are equivalent to the following operations.
print(a[0, :] + a[1, :])
# [3 5 7]
print(a[:, 0] + a[:, 1] + a[:, 2])
# [ 3 12]
: is used to select all elements along a given axis. The trailing : can be omitted, but is included in the above example for clarity.
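The slice-by-slice additions above can also be written as a loop over the axis being reduced, which may make it clearer that the index on that axis is the one that varies. This is an illustrative sketch, not how NumPy implements the reduction; acc0 and acc1 are names chosen here.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)

# Sum over axis 0: add up the slices a[i, :] for each index i along axis 0
acc0 = np.zeros(a.shape[1], dtype=a.dtype)
for i in range(a.shape[0]):
    acc0 += a[i, :]
print(acc0)
# [3 5 7]

# Sum over axis 1: add up the slices a[:, j] for each index j along axis 1
acc1 = np.zeros(a.shape[0], dtype=a.dtype)
for j in range(a.shape[1]):
    acc1 += a[:, j]
print(acc1)
# [ 3 12]
```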
In the case of a 3D array
The same applies to three-dimensional and higher multi-dimensional arrays. Consider a three-dimensional array with the shape (2, 3, 4).
a = np.stack([np.ones((3, 4), int), np.full((3, 4), 10)])
print(a)
# [[[ 1 1 1 1]
# [ 1 1 1 1]
# [ 1 1 1 1]]
#
# [[10 10 10 10]
# [10 10 10 10]
# [10 10 10 10]]]
print(a.shape)
# (2, 3, 4)
np.sum() returns a two-dimensional array with the shape (3, 4) for axis=0, (2, 4) for axis=1, and (2, 3) for axis=2.
print(np.sum(a, axis=0))
# [[11 11 11 11]
# [11 11 11 11]
# [11 11 11 11]]
print(np.sum(a, axis=1))
# [[ 3 3 3 3]
# [30 30 30 30]]
print(np.sum(a, axis=2))
# [[ 4 4 4]
# [40 40 40]]
This operation aggregates the input array along the axis specified by axis, removing that axis from the shape while preserving the other axes.
input : (2, 3, 4)
axis=0 : (_, 3, 4) -> (3, 4)
axis=1 : (2, _, 4) -> (2, 4)
axis=2 : (2, 3, _) -> (2, 3)
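The shape rule above can be expressed as a small function and checked against np.sum(). This is a sketch only: reduced_shape is a hypothetical helper, and it assumes the axis passed in is valid for the given shape.

```python
import numpy as np

def reduced_shape(shape, axis):
    """Hypothetical helper: shape left after reducing a single (valid) axis."""
    axis = axis % len(shape)  # map a negative axis to its non-negative equivalent
    return shape[:axis] + shape[axis + 1:]

a = np.ones((2, 3, 4), int)
for axis in range(a.ndim):
    # The predicted shape matches what np.sum() actually returns
    assert np.sum(a, axis=axis).shape == reduced_shape(a.shape, axis)
    print(axis, reduced_shape(a.shape, axis))
# 0 (3, 4)
# 1 (2, 4)
# 2 (2, 3)
```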
np.sum(a, axis=0), np.sum(a, axis=1), and np.sum(a, axis=2) are equivalent to the following operations.
print(a[0, :, :] + a[1, :, :])
# [[11 11 11 11]
# [11 11 11 11]
# [11 11 11 11]]
print(a[:, 0, :] + a[:, 1, :] + a[:, 2, :])
# [[ 3 3 3 3]
# [30 30 30 30]]
print(a[:, :, 0] + a[:, :, 1] + a[:, :, 2] + a[:, :, 3])
# [[ 4 4 4]
# [40 40 40]]
When dealing with three-dimensional or higher multi-dimensional arrays, it may be more helpful to consider which axes to keep in the shape, rather than thinking in terms of rows, columns, depth, etc.
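One way to put the "which axes to keep" view into code is a helper that takes the axes to keep and reduces all the others. The sketch below relies on the fact that axis accepts a tuple of integers; sum_keeping is a hypothetical helper written for this example, not part of NumPy.

```python
import numpy as np

def sum_keeping(arr, keep):
    """Hypothetical helper: sum over every axis except those listed in keep."""
    keep = {k % arr.ndim for k in keep}  # normalize negative axes
    reduce_axes = tuple(i for i in range(arr.ndim) if i not in keep)
    return np.sum(arr, axis=reduce_axes)

a = np.ones((2, 3, 4), int)

# Keep only axis 1: the result has shape (3,)
print(sum_keeping(a, keep=[1]).shape)
# (3,)

# Keep axes 0 and 2: the result has shape (2, 4)
print(sum_keeping(a, keep=[0, 2]).shape)
# (2, 4)
```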
When specifying multiple values for axis
You can specify multiple values for the axis parameter with a tuple. The same concept applies here as well.
a = np.stack([np.ones((3, 4), int), np.full((3, 4), 10)])
print(a)
# [[[ 1 1 1 1]
# [ 1 1 1 1]
# [ 1 1 1 1]]
#
# [[10 10 10 10]
# [10 10 10 10]
# [10 10 10 10]]]
print(a.shape)
# (2, 3, 4)
print(np.sum(a, axis=(0, 1)))
# [33 33 33 33]
print(np.sum(a, axis=(0, 2)))
# [44 44 44]
print(np.sum(a, axis=(1, 2)))
# [ 12 120]
The relationship between the shapes of the input and output arrays can be considered as follows:
input : (2, 3, 4)
axis=(0, 1) : (_, _, 4) -> (4,)
axis=(0, 2) : (_, 3, _) -> (3,)
axis=(1, 2) : (2, _, _) -> (2,)
Each is equivalent to the following operations.
print(
a[0, 0, :] + a[0, 1, :] + a[0, 2, :] +
a[1, 0, :] + a[1, 1, :] + a[1, 2, :]
)
# [33 33 33 33]
print(
a[0, :, 0] + a[0, :, 1] + a[0, :, 2] + a[0, :, 3] +
a[1, :, 0] + a[1, :, 1] + a[1, :, 2] + a[1, :, 3]
)
# [44 44 44]
print(
a[:, 0, 0] + a[:, 0, 1] + a[:, 0, 2] + a[:, 0, 3] +
a[:, 1, 0] + a[:, 1, 1] + a[:, 1, 2] + a[:, 1, 3] +
a[:, 2, 0] + a[:, 2, 1] + a[:, 2, 2] + a[:, 2, 3]
)
# [ 12 120]
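For a sum, reducing several axes at once gives the same result as reducing them one after another. The sketch below shows this for axis=(0, 2); note that after one axis is removed, the remaining axes are renumbered, so the order of the steps affects which axis number to pass next.

```python
import numpy as np

a = np.stack([np.ones((3, 4), int), np.full((3, 4), 10)])

# Reduce axis 2 first, then axis 0 of the intermediate result
step = np.sum(a, axis=2)          # shape (2, 3)
two_step = np.sum(step, axis=0)   # shape (3,)
print(two_step)
# [44 44 44]

# Same result as reducing both axes in one call
print(np.sum(a, axis=(0, 2)))
# [44 44 44]

# Reducing axis 0 first shifts the rest: the original axis 2 becomes axis 1
shifted = np.sum(np.sum(a, axis=0), axis=1)
print(shifted)
# [44 44 44]
```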
keepdims maintains the dimensions of the output array
Reduction functions and methods that take the axis parameter, such as np.sum(), also support the keepdims parameter. With keepdims=True, the output array has the same number of dimensions as the input array, and each reduced axis is kept with size 1.
np.sum() is used as an example, but the same applies to other functions and methods like np.mean().
When a single axis is specified, the default output for a two-dimensional input array is one-dimensional, but keepdims=True returns a two-dimensional array.
a = np.ones((3, 4), int)
print(a)
# [[1 1 1 1]
# [1 1 1 1]
# [1 1 1 1]]
print(a.shape)
# (3, 4)
print(np.sum(a, axis=1))
# [4 4 4]
print(np.sum(a, axis=1).shape)
# (3,)
print(np.sum(a, axis=1, keepdims=True))
# [[4]
# [4]
# [4]]
print(np.sum(a, axis=1, keepdims=True).shape)
# (3, 1)
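As a rough illustration of what keepdims=True does to the shape, the same result can be obtained by reducing normally and then re-inserting the reduced axis with np.expand_dims(). The names kept and manual are chosen here for the example.

```python
import numpy as np

a = np.ones((3, 4), int)

kept = np.sum(a, axis=1, keepdims=True)

# Reduce without keepdims, then put a length-1 axis back at position 1
manual = np.expand_dims(np.sum(a, axis=1), axis=1)

print(kept.shape, manual.shape)
# (3, 1) (3, 1)
print(np.array_equal(kept, manual))
# True
```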
With keepdims=True, the output array is correctly broadcast with the input array.
In NumPy, during binary operations, array shapes are automatically aligned through broadcasting where possible.
When the output array is combined with the input array (or any array with the same shape as the input), the default result for axis=1 has the shape (3,), which cannot be broadcast against (3, 4), so an error is raised. keepdims=True ensures proper broadcasting.
# print(a + np.sum(a, axis=1))
# ValueError: operands could not be broadcast together with shapes (3,4) (3,)
print(a + np.sum(a, axis=1, keepdims=True))
# [[5 5 5 5]
# [5 5 5 5]
# [5 5 5 5]]
For axis=0, the default output has the shape (4,), which broadcasts correctly against (3, 4) because broadcasting aligns shapes from the trailing dimension. keepdims=True gives the same result.
print(a + np.sum(a, axis=0))
# [[4 4 4 4]
# [4 4 4 4]
# [4 4 4 4]]
print(a + np.sum(a, axis=0, keepdims=True))
# [[4 4 4 4]
# [4 4 4 4]
# [4 4 4 4]]
For operations involving broadcasting to the input array, it may be safer to set keepdims=True to avoid mistakes.
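One such mistake is easy to miss with square arrays: the default output of axis=1 broadcasts without error, but the values are aligned with the columns rather than the rows. The following sketch normalizes each row to sum to 1; the array sq and the variable names are made up for this illustration.

```python
import numpy as np

sq = np.arange(1, 10).reshape(3, 3)
# [[1 2 3]
#  [4 5 6]
#  [7 8 9]]

row_sums = np.sum(sq, axis=1)  # shape (3,)

# No error is raised, but each column is divided by a row sum
wrong = sq / row_sums

# With keepdims=True, each row is divided by its own sum
right = sq / np.sum(sq, axis=1, keepdims=True)

print(np.allclose(wrong.sum(axis=1), 1.0))
# False
print(np.allclose(right.sum(axis=1), 1.0))
# True
```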
The same applies to three-dimensional and higher multi-dimensional arrays. Regardless of the value specified for axis, keepdims=True ensures that the output array is correctly broadcast with the input array.
a = np.ones((2, 3, 4), int)
print(a)
# [[[1 1 1 1]
# [1 1 1 1]
# [1 1 1 1]]
#
# [[1 1 1 1]
# [1 1 1 1]
# [1 1 1 1]]]
print(np.sum(a, axis=(0, 2)))
# [8 8 8]
print(np.sum(a, axis=(0, 2), keepdims=True))
# [[[8]
# [8]
# [8]]]
# print(a + np.sum(a, axis=(0, 2)))
# ValueError: operands could not be broadcast together with shapes (2,3,4) (3,)
print(a + np.sum(a, axis=(0, 2), keepdims=True))
# [[[9 9 9 9]
# [9 9 9 9]
# [9 9 9 9]]
#
# [[9 9 9 9]
# [9 9 9 9]
# [9 9 9 9]]]
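A typical use of this pattern is subtracting a per-group mean from the input. The sketch below centers a (2, 3, 4) array so that the mean over axes 0 and 2 becomes zero for each index along axis 1; the data is arbitrary and np.mean() is used in place of np.sum().

```python
import numpy as np

a = np.arange(24, dtype=float).reshape(2, 3, 4)

# One mean per index along axis 1, kept as shape (1, 3, 1) for broadcasting
mean = np.mean(a, axis=(0, 2), keepdims=True)
centered = a - mean

print(mean.shape)
# (1, 3, 1)
print(centered.shape)
# (2, 3, 4)
print(np.allclose(np.mean(centered, axis=(0, 2)), 0.0))
# True
```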