NumPy: Split an array with np.split, np.vsplit, np.hsplit, etc.

Posted: | Tags: Python, NumPy

In NumPy, to split an array (ndarray), the following functions are used:

  • np.split(): For splitting into equal parts or at specific positions
  • np.array_split(): For splitting as equally as possible
  • np.vsplit(): For vertical splitting
  • np.hsplit(): For horizontal splitting
  • np.dsplit(): For splitting along the depth

np.split() is the fundamental function, with the others provided for convenience for specific purposes. Understanding np.split() makes it easier to grasp how the others work.

The terms "vertical split" and "horizontal split" might seem ambiguous. In the official NumPy documentation, a "vertical split" is row-wise and divides an array into top and bottom parts, while a "horizontal split" is column-wise and divides it into left and right parts.

It is important to note that the arrays resulting from the split are views of the original array. Changing a value in the original array changes the corresponding value in the split arrays, and vice versa.

For information on concatenating multiple arrays, refer to the following article.

The NumPy version used in this article is as follows. Note that functionality may vary between versions.

import numpy as np

print(np.__version__)
# 1.26.1

Basic usage of np.split()

The fundamental function for splitting an array (ndarray) is np.split().

Returns a list of arrays

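np.split() takes the array to be split as the first argument. The second argument (indices_or_sections) specifies the number of sections or the split positions, and the third argument (axis) specifies the axis to split along.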

For example, to split vertically into two equal parts, set the second argument to 2 and omit the third argument (details are discussed later).

The function returns a list of arrays, and the original array remains unchanged.

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

a_split = np.split(a, 2)

print(type(a_split))
# <class 'list'>

print(len(a_split))
# 2

print(a_split[0])
# [[0 1 2 3]
#  [4 5 6 7]]

print(a_split[1])
# [[ 8  9 10 11]
#  [12 13 14 15]]

print(type(a_split[0]))
# <class 'numpy.ndarray'>

print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

It is also possible to use unpacking to store them in separate variables.

a0, a1 = np.split(a, 2)

print(a0)
# [[0 1 2 3]
#  [4 5 6 7]]

print(a1)
# [[ 8  9 10 11]
#  [12 13 14 15]]

Specify the number of splits or split positions: indices_or_sections

Specify the number of splits with an integer

Specifying an integer (int) as the second argument, indices_or_sections, splits the array into that many equal parts. An error occurs if the number does not evenly divide the array along the specified axis.

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

a0, a1 = np.split(a, 2)

print(a0)
# [[0 1 2 3]
#  [4 5 6 7]]

print(a1)
# [[ 8  9 10 11]
#  [12 13 14 15]]

# np.split(a, 3)
# ValueError: array split does not result in an equal division
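When the integer divides the axis length evenly, splitting into n sections should give the same pieces as splitting at evenly spaced positions. The following is a small check, assuming the same 4x4 array and comparing the pieces with np.array_equal():

```python
import numpy as np

a = np.arange(16).reshape(4, 4)

# Two equal sections along axis 0 ...
by_count = np.split(a, 2)
# ... should match splitting before row 2 (4 rows // 2 sections).
by_position = np.split(a, [2])

for x, y in zip(by_count, by_position):
    print(np.array_equal(x, y))
# True
# True
```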

If the number does not divide the array evenly, use np.array_split() instead, which adjusts the sizes of the sub-arrays as needed, as described later.

Specify split positions with a list

Specifying a list of integers as the second argument, indices_or_sections, splits the array at those index positions, with indexing starting at 0.

For example, specifying [1, 3] splits the array before the 1st row (between the 0th and 1st rows) and before the 3rd row (between the 2nd and 3rd rows).

a0, a1, a2 = np.split(a, [1, 3])

print(a0)
# [[0 1 2 3]]

print(a1)
# [[ 4  5  6  7]
#  [ 8  9 10 11]]

print(a2)
# [[12 13 14 15]]
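One way to read the list of positions is as the boundaries of ordinary slices: [1, 3] produces the pieces [:1], [1:3], and [3:]. A minimal check of that reading on the same 4x4 array:

```python
import numpy as np

a = np.arange(16).reshape(4, 4)

a0, a1, a2 = np.split(a, [1, 3])

# Split positions [1, 3] correspond to the slices [:1], [1:3], and [3:].
print(np.array_equal(a0, a[:1]))   # True
print(np.array_equal(a1, a[1:3]))  # True
print(np.array_equal(a2, a[3:]))   # True
```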

To split at a single position, specify a list with one element.

a0, a1 = np.split(a, [1])

print(a0)
# [[0 1 2 3]]

print(a1)
# [[ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

If a specified index is out of range, the corresponding sub-array is empty.

a0, a1 = np.split(a, [10])

print(a0)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

print(a1)
# []

print(type(a1))
# <class 'numpy.ndarray'>
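Although the empty piece prints as [], it is still a 2D array. Checking its shape shows that it has zero rows but keeps the number of columns, since it corresponds to a[10:]:

```python
import numpy as np

a = np.arange(16).reshape(4, 4)

a0, a1 = np.split(a, [10])

# The second piece corresponds to a[10:], which has zero rows
# but keeps the original number of columns.
print(a0.shape)  # (4, 4)
print(a1.shape)  # (0, 4)
print(a1.size)   # 0
```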

Specify the axis to split: axis

The axis (dimension) along which to split the array is specified by the third argument, axis.

If this argument is omitted, as in the examples so far, axis=0 is used. Specifying axis=0 explicitly gives the same result: the array is split along the 0th axis, i.e., by rows in a 2D array.

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

a0, a1 = np.split(a, 2, 0)

print(a0)
# [[0 1 2 3]
#  [4 5 6 7]]

print(a1)
# [[ 8  9 10 11]
#  [12 13 14 15]]

Specifying axis=1 splits the array along the 1st axis, i.e., by columns in 2D arrays.

a0, a1 = np.split(a, 2, 1)

print(a0)
# [[ 0  1]
#  [ 4  5]
#  [ 8  9]
#  [12 13]]

print(a1)
# [[ 2  3]
#  [ 6  7]
#  [10 11]
#  [14 15]]

Specifying a non-existent axis results in an error.

# np.split(a, 2, 2)
# IndexError: tuple index out of range
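The axis can also be passed as a keyword argument, and, as elsewhere in NumPy, a negative value counts from the last axis. The following sketch, assuming the same 4x4 array, compares axis=1 with axis=-1:

```python
import numpy as np

a = np.arange(16).reshape(4, 4)

# axis can be passed as a keyword argument.
by_keyword = np.split(a, 2, axis=1)

# A negative axis counts from the end, so for a 2D array axis=-1 means axis=1.
by_negative = np.split(a, 2, axis=-1)

for x, y in zip(by_keyword, by_negative):
    print(np.array_equal(x, y))
# True
# True
```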

Examples with arrays of three or more dimensions

While the explanations so far have used terms like "rows" and "columns" for simplicity, the same concepts apply to arrays of three or more dimensions.

Consider the following array as an example.

a_3d = np.arange(24).reshape(2, 3, 4)
print(a_3d)
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
# 
#  [[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]

print(a_3d.shape)
# (2, 3, 4)

Use the third argument (axis) to specify the target axis, and the second argument (indices_or_sections) to determine the number of splits or their positions.

a0, a1 = np.split(a_3d, 2, 0)

print(a0)
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]]

print(a1)
# [[[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]

a0, a1 = np.split(a_3d, [1], 2)

print(a0)
# [[[ 0]
#   [ 4]
#   [ 8]]
# 
#  [[12]
#   [16]
#   [20]]]

print(a1)
# [[[ 1  2  3]
#   [ 5  6  7]
#   [ 9 10 11]]
# 
#  [[13 14 15]
#   [17 18 19]
#   [21 22 23]]]
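A quick way to see what happens is to look only at the shapes. Splitting along an axis changes the length of that axis and leaves the others untouched. The loop below splits the same (2, 3, 4) array at position [1] along each axis in turn:

```python
import numpy as np

a_3d = np.arange(24).reshape(2, 3, 4)

# Split at position 1 along each axis and report the resulting shapes.
for axis in range(a_3d.ndim):
    pieces = np.split(a_3d, [1], axis)
    print(axis, [p.shape for p in pieces])
# 0 [(1, 3, 4), (1, 3, 4)]
# 1 [(2, 1, 4), (2, 2, 4)]
# 2 [(2, 3, 1), (2, 3, 3)]
```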

np.array_split() splits an array as equally as possible

np.array_split() works similarly to np.split(), but allows specifying an integer for indices_or_sections that does not evenly divide the array along the specified axis.

Consider the following array.

a = np.arange(15).reshape(3, 5)
print(a)
# [[ 0  1  2  3  4]
#  [ 5  6  7  8  9]
#  [10 11 12 13 14]]

np.split() results in an error if indices_or_sections does not evenly divide the array.

# np.split(a, 2, 0)
# ValueError: array split does not result in an equal division

np.array_split() handles this without an error by adjusting the sizes of the sub-arrays. In this example, the first sub-array gets the extra row.

a0, a1 = np.array_split(a, 2, 0)

print(a0)
# [[0 1 2 3 4]
#  [5 6 7 8 9]]

print(a1)
# [[10 11 12 13 14]]

The adjustment follows this rule:

For an array of length l that should be split into n sections, it returns l % n sub-arrays of size l//n + 1 and the rest of size l//n.
numpy.array_split — NumPy v1.26 Manual

This means the remainder is distributed one element at a time among the first l % n sub-arrays. In the following example, 5 columns are split into 3 sections, so the first two arrays have one additional column each (sizes 2, 2, and 1).

a0, a1, a2 = np.array_split(a, 3, 1)

print(a0)
# [[ 0  1]
#  [ 5  6]
#  [10 11]]

print(a1)
# [[ 2  3]
#  [ 7  8]
#  [12 13]]

print(a2)
# [[ 4]
#  [ 9]
#  [14]]
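The documented rule can be written out directly and compared with what np.array_split() returns. The function expected_sizes() below is a hypothetical helper written only to restate the quoted rule; it is not part of NumPy. The check uses the 5 columns of the 3x5 array above:

```python
import numpy as np


def expected_sizes(l, n):
    """Hypothetical helper: section sizes predicted by the documented rule."""
    # l % n sub-arrays of size l // n + 1, then the rest of size l // n.
    return [l // n + 1] * (l % n) + [l // n] * (n - l % n)


a = np.arange(15).reshape(3, 5)

# Compare the predicted sizes with the actual sizes along axis 1 (length 5).
for n in range(1, 6):
    actual = [p.shape[1] for p in np.array_split(a, n, axis=1)]
    print(n, expected_sizes(5, n), actual)
# 1 [5] [5]
# 2 [3, 2] [3, 2]
# 3 [2, 2, 1] [2, 2, 1]
# 4 [2, 1, 1, 1] [2, 1, 1, 1]
# 5 [1, 1, 1, 1, 1] [1, 1, 1, 1, 1]
```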

np.vsplit() splits an array vertically

np.vsplit() splits an array vertically, equivalent to np.split() with axis=0.

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

a0, a1 = np.vsplit(a, 2)

print(a0)
# [[0 1 2 3]
#  [4 5 6 7]]

print(a1)
# [[ 8  9 10 11]
#  [12 13 14 15]]

Like np.split(), specifying an integer for the second argument (indices_or_sections) results in an error if the array cannot be evenly divided. Alternatively, specifying a list for indices_or_sections allows for splitting at specific positions.

a0, a1 = np.vsplit(a, [1])

print(a0)
# [[0 1 2 3]]

print(a1)
# [[ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]
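To confirm the equivalence with np.split() and axis=0, both functions can be called with the same indices_or_sections and the pieces compared. A short check, assuming the same 4x4 array and a few arbitrary split specifications:

```python
import numpy as np

a = np.arange(16).reshape(4, 4)

# Same integer or list passed to both functions.
for spec in (2, [1], [1, 3]):
    v = np.vsplit(a, spec)
    s = np.split(a, spec, axis=0)
    same = len(v) == len(s) and all(np.array_equal(x, y) for x, y in zip(v, s))
    print(spec, same)
# 2 True
# [1] True
# [1, 3] True
```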

np.hsplit() splits an array horizontally

np.hsplit() splits an array horizontally. It is almost equivalent to np.split() with axis=1; the exception, 1D arrays, is described below.

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

a0, a1 = np.hsplit(a, 2)

print(a0)
# [[ 0  1]
#  [ 4  5]
#  [ 8  9]
#  [12 13]]

print(a1)
# [[ 2  3]
#  [ 6  7]
#  [10 11]
#  [14 15]]

The second argument (indices_or_sections) works as it does in other functions described above, allowing for either an integer or a list to be specified.

a0, a1 = np.hsplit(a, [1])

print(a0)
# [[ 0]
#  [ 4]
#  [ 8]
#  [12]]

print(a1)
# [[ 1  2  3]
#  [ 5  6  7]
#  [ 9 10 11]
#  [13 14 15]]

Unlike np.split() with axis=1, np.hsplit() can split 1D arrays.

a_1d = np.arange(6)
print(a_1d)
# [0 1 2 3 4 5]

# np.split(a_1d, 2, 1)
# IndexError: tuple index out of range

a0, a1 = np.hsplit(a_1d, 2)

print(a0)
# [0 1 2]

print(a1)
# [3 4 5]

This is because, in the NumPy source code, np.hsplit() calls np.split() with axis=1 for arrays with two or more dimensions, but with axis=0 for 1D arrays.
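This axis choice can be reproduced in a few lines. The function hsplit_sketch() below is a hypothetical simplification written for illustration, not NumPy's actual code: it picks axis 1 for arrays with two or more dimensions and axis 0 for 1D arrays, then delegates to np.split(). Its results are compared with np.hsplit():

```python
import numpy as np


def hsplit_sketch(ary, indices_or_sections):
    """Hypothetical sketch of the axis choice made by np.hsplit()."""
    ary = np.asarray(ary)
    if ary.ndim == 0:
        raise ValueError("hsplit only works on arrays of 1 or more dimensions")
    # Columns are on axis 1 for 2D+ arrays; a 1D array only has axis 0.
    axis = 1 if ary.ndim > 1 else 0
    return np.split(ary, indices_or_sections, axis)


a = np.arange(16).reshape(4, 4)
a_1d = np.arange(6)

for arr in (a, a_1d):
    mine = hsplit_sketch(arr, 2)
    ref = np.hsplit(arr, 2)
    print(all(np.array_equal(x, y) for x, y in zip(mine, ref)))
# True
# True
```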

np.dsplit() splits an array along the depth

np.dsplit() splits an array along the depth direction, equivalent to np.split() with axis=2.

a_3d = np.arange(24).reshape(2, 3, 4)
print(a_3d)
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
# 
#  [[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]

print(a_3d.shape)
# (2, 3, 4)

a0, a1 = np.dsplit(a_3d, 2)

print(a0)
# [[[ 0  1]
#   [ 4  5]
#   [ 8  9]]
# 
#  [[12 13]
#   [16 17]
#   [20 21]]]

print(a0.shape)
# (2, 3, 2)

print(a1)
# [[[ 2  3]
#   [ 6  7]
#   [10 11]]
# 
#  [[14 15]
#   [18 19]
#   [22 23]]]

print(a1.shape)
# (2, 3, 2)

The second argument (indices_or_sections) works as it does in other functions described above, allowing for either an integer or a list to be specified.

a0, a1 = np.dsplit(a_3d, [1])

print(a0)
# [[[ 0]
#   [ 4]
#   [ 8]]
# 
#  [[12]
#   [16]
#   [20]]]

print(a1)
# [[[ 1  2  3]
#   [ 5  6  7]
#   [ 9 10 11]]
# 
#  [[13 14 15]
#   [17 18 19]
#   [21 22 23]]]
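As with np.vsplit(), the equivalence with np.split() can be checked directly, here with axis=2 and the same (2, 3, 4) array:

```python
import numpy as np

a_3d = np.arange(24).reshape(2, 3, 4)

d = np.dsplit(a_3d, [1])
s = np.split(a_3d, [1], axis=2)

print(len(d) == len(s))                                 # True
print(all(np.array_equal(x, y) for x, y in zip(d, s)))  # True
print([x.shape for x in d])                             # [(2, 3, 1), (2, 3, 3)]
```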

np.dsplit() is intended for arrays with three or more dimensions. Using it on an array with two or fewer dimensions results in an error.

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

# np.dsplit(a, 2)
# ValueError: dsplit only works on arrays of 3 or more dimensions

Split arrays are views of the original array

Functions such as np.split() return a list of arrays, which are views of the original array.

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

l = np.split(a, 2)

print(l[0])
# [[0 1 2 3]
#  [4 5 6 7]]

print(np.shares_memory(a, l[0]))
# True

Changing a value in the original array also changes the corresponding value in the split array.

a[0, 0] = 100
print(a)
# [[100   1   2   3]
#  [  4   5   6   7]
#  [  8   9  10  11]
#  [ 12  13  14  15]]

print(l[0])
# [[100   1   2   3]
#  [  4   5   6   7]]
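The sharing goes in both directions, but only between each piece and the original. When splitting by rows, the pieces cover disjoint parts of the original's memory, so they do not overlap each other. The following check uses np.shares_memory() on a fresh 4x4 array:

```python
import numpy as np

a = np.arange(16).reshape(4, 4)
l = np.split(a, 2)

# Each piece overlaps the original array ...
print(np.shares_memory(a, l[0]), np.shares_memory(a, l[1]))  # True True
# ... but the pieces cover disjoint rows, so they do not overlap each other.
print(np.shares_memory(l[0], l[1]))  # False

# Writing through a piece changes the original array (l[1][0, 0] is a[2, 0]).
l[1][0, 0] = -1
print(a[2, 0])  # -1
```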

The same applies to the other splitting functions.

print(np.shares_memory(a, np.vsplit(a, 2)[0]))
# True

print(np.shares_memory(a, np.hsplit(a, 2)[0]))
# True

print(np.shares_memory(a, np.array_split(a, 3)[0]))
# True

a_3d = np.arange(24).reshape(2, 3, 4)
print(np.shares_memory(a_3d, np.dsplit(a_3d, 2)[0]))
# True

If you need arrays that can be modified independently of the original, create a copy of the original array with copy() before splitting.

a = np.arange(16).reshape(4, 4)

l_copy = np.split(a.copy(), 2)

print(np.shares_memory(a, l_copy[0]))
# False

a[0, 0] = 100
print(a)
# [[100   1   2   3]
#  [  4   5   6   7]
#  [  8   9  10  11]
#  [ 12  13  14  15]]

print(l_copy[0])
# [[0 1 2 3]
#  [4 5 6 7]]
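Another option, shown here as one possible approach rather than the only one, is to split first and then copy each resulting piece. This can be convenient when only some of the pieces need to be independent:

```python
import numpy as np

a = np.arange(16).reshape(4, 4)

# Split first, then copy each resulting piece.
l_copy = [x.copy() for x in np.split(a, 2)]

print(any(np.shares_memory(a, x) for x in l_copy))  # False

# Modifying the original no longer affects the copied pieces.
a[0, 0] = 100
print(l_copy[0][0, 0])  # 0
```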

For more details on views and copies in NumPy, refer to the following article.
