note.nkmk.me

How to flatten a list of lists in Python

Posted: 2020-07-21 / Modified: 2021-09-01 / Tags: Python, List

This article describes how to flatten a list of lists (a multi-dimensional or nested list) in Python. You can use itertools.chain.from_iterable(), sum(), or list comprehensions.

  • Flatten 2D list (a list of lists)
    • Flatten list with itertools.chain.from_iterable()
    • Flatten list with sum()
    • Flatten list with list comprehensions
    • Speed comparison
  • Flatten 3D and more multidimensional lists and irregular lists

To flatten a NumPy array (ndarray), use ravel() or flatten() instead.

Conversely, see the following article for how to convert a one-dimensional ndarray or list to two dimensions.


Flatten 2D list (a list of lists)

Flatten list with itertools.chain.from_iterable()

A list of lists (2D list) can be flattened with itertools.chain.from_iterable().

import itertools

l_2d = [[0, 1], [2, 3]]

print(list(itertools.chain.from_iterable(l_2d)))
# [0, 1, 2, 3]

itertools.chain.from_iterable() returns an iterator, so use list() if you need a list. You don't need to convert it when you just iterate over it in a for loop.
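Because the result is an iterator, you can loop over it directly, as in the minimal sketch below. It also shows itertools.chain() with argument unpacking (*), which gives the same result for a list of lists; from_iterable() is the variant that takes the outer list as a single argument.

```python
import itertools

l_2d = [[0, 1], [2, 3]]

# Loop over the flattened elements without building a list first
for x in itertools.chain.from_iterable(l_2d):
    print(x)
# 0
# 1
# 2
# 3

# itertools.chain() takes the inner lists as separate arguments,
# so the outer list has to be unpacked with *
print(list(itertools.chain(*l_2d)))
# [0, 1, 2, 3]
```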

A tuple of tuples can be handled in the same way. In the following example, the result is converted to a tuple by using tuple(). If you need a list, use list().

t_2d = ((0, 1), (2, 3))

print(tuple(itertools.chain.from_iterable(t_2d)))
# (0, 1, 2, 3)

itertools.chain.from_iterable() removes only one level of nesting, so it fully flattens only 2D lists. For 3D or higher-dimensional lists, the result is as follows.

l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

print(list(itertools.chain.from_iterable(l_3d)))
# [[0, 1], [2, 3], [4, 5], [6, 7]]

An error occurs if the list contains non-iterable elements such as int or float.

l_mix = [[0, 1], [2, 3], 4]

# print(list(itertools.chain.from_iterable(l_mix)))
# TypeError: 'int' object is not iterable

Flattening 3D or higher-dimensional lists and irregular lists is described later.

Flatten list with sum()

You can also use the built-in function sum() to flatten a list.

The second argument of sum() specifies the start value. If you pass an empty list [], the inner lists are concatenated to it with the + operator.

Note that the default start value is 0, so if you omit it, a TypeError is raised because + is not supported between int and list.

print(sum(l_2d, []))
# [0, 1, 2, 3]

# print(sum(l_2d))
# TypeError: unsupported operand type(s) for +: 'int' and 'list'
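Roughly speaking, sum(l_2d, []) starts from [] and adds each inner list in turn with +. The explicit loop below is a minimal sketch of that evaluation order (not CPython's actual implementation); note that each + builds a new list.

```python
l_2d = [[0, 1], [2, 3]]

# Sketch of what sum(l_2d, []) computes: start from the start value []
# and concatenate each inner list with +
result = []
for row in l_2d:
    result = result + row  # + creates a new list every time
print(result)
# [0, 1, 2, 3]

print(result == sum(l_2d, []))
# True
```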

Tuples can be handled in the same way.

print(sum(t_2d, ()))
# (0, 1, 2, 3)

As with itertools.chain.from_iterable(), only one level is flattened, so it doesn't work for 3D or higher-dimensional lists, and it fails for irregular lists.

print(sum(l_3d, []))
# [[0, 1], [2, 3], [4, 5], [6, 7]]

# print(sum(l_mix, []))
# TypeError: can only concatenate list (not "int") to list

Flatten list with list comprehensions

You can also use a list comprehension with nested for clauses.

matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

flat = [x for row in matrix for x in row]
print(flat)
# [1, 2, 3, 4, 5, 6, 7, 8, 9]

The example above is equivalent to the following nested for loop.

flat = []
for row in matrix:
    for x in row:
        flat.append(x)

print(flat)
# [1, 2, 3, 4, 5, 6, 7, 8, 9]

As with the other methods, the list comprehension above flattens only one level, and an error occurs if the list contains non-iterable elements.

You could add more for clauses to handle three or more dimensions, or branch on the type of each element, but this quickly becomes complicated and is not recommended.
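For a regular 3D list, one extra for clause is still manageable. This is a minimal sketch that assumes every second-level element is itself a list; for an irregular list such as l_mix it raises TypeError.

```python
l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

# One for clause per level of nesting: matrix -> row -> x
flat = [x for matrix in l_3d for row in matrix for x in row]
print(flat)
# [0, 1, 2, 3, 4, 5, 6, 7]
```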

See the following article for more information on list comprehensions.

Speed comparison

Note that although sum() is easy to use, it is much slower than itertools.chain.from_iterable() or list comprehensions when there are many inner lists. Each + creates a new list and copies every element accumulated so far, so the processing time grows roughly with the square of the number of inner lists. Avoid sum() when there are many inner lists and processing speed or memory efficiency matters.
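If you prefer a plain loop that avoids creating a new list for each inner list, list.extend() appends to a single list in place. The following is a minimal sketch for comparison; the function name flatten_with_extend() is hypothetical, and this approach is not one of the methods timed below.

```python
def flatten_with_extend(l_2d):
    """Flatten one level by extending a single result list in place."""
    result = []
    for row in l_2d:
        result.extend(row)  # modifies result; no new list is created
    return result

l_2d_5 = [[0, 1, 2] for i in range(5)]
print(flatten_with_extend(l_2d_5))
# [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
```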

itertools.chain.from_iterable() requires importing itertools, but it is faster than a list comprehension. For very small inputs (the first example below), sum() is actually the fastest of the three.

The following timings were measured with the %%timeit magic command in Jupyter Notebook. Note that %%timeit does not work in a regular Python script. The results depend on the environment.
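In a regular Python script, the timeit module from the standard library can be used instead. This is a minimal sketch; the function name compare() and the choice of 1000 runs are assumptions for illustration, and the absolute numbers will differ from those below depending on the environment.

```python
import itertools
import timeit

def compare(l_2d, number=1000):
    """Return the total seconds taken by `number` runs of each method."""
    methods = {
        "chain.from_iterable": lambda: list(itertools.chain.from_iterable(l_2d)),
        "sum": lambda: sum(l_2d, []),
        "list comprehension": lambda: [x for row in l_2d for x in row],
    }
    # Sanity check: every method must produce the same flattened list
    expected = methods["chain.from_iterable"]()
    for name, func in methods.items():
        if func() != expected:
            raise ValueError(f"{name} returned a different result")
    # timeit.timeit() accepts a callable and runs it `number` times
    return {name: timeit.timeit(func, number=number) for name, func in methods.items()}

number = 1000
l_2d_100 = [[0, 1, 2] for i in range(100)]
for name, seconds in compare(l_2d_100, number).items():
    # Convert the total time to microseconds per run
    print(f"{name}: {seconds / number * 1e6:.2f} µs per loop")
```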

5 inner lists:

l_2d_5 = [[0, 1, 2] for i in range(5)]
print(l_2d_5)
# [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]

%%timeit
list(itertools.chain.from_iterable(l_2d_5))
# 537 ns ± 4.59 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

%%timeit
sum(l_2d_5, [])
# 319 ns ± 1.85 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

%%timeit
[x for row in l_2d_5 for x in row]
# 764 ns ± 32.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

100 inner lists:

l_2d_100 = [[0, 1, 2] for i in range(100)]

%%timeit
list(itertools.chain.from_iterable(l_2d_100))
# 6.94 µs ± 139 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%%timeit
sum(l_2d_100, [])
# 35.5 µs ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit
[x for row in l_2d_100 for x in row]
# 13.5 µs ± 959 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

10000 inner lists:

l_2d_10000 = [[0, 1, 2] for i in range(10000)]

%%timeit
list(itertools.chain.from_iterable(l_2d_10000))
# 552 µs ± 79.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
sum(l_2d_10000, [])
# 343 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
[x for row in l_2d_10000 for x in row]
# 1.11 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Flatten 3D and more multidimensional lists and irregular lists

To flatten 3D or higher-dimensional lists or irregular lists, you need to define your own function.

The sample code below is based on the following article.

import collections.abc

def flatten(l):
    for el in l:
        if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el

Each element el is checked with isinstance(), and iterable elements are flattened recursively.

collections.abc.Iterable is used to determine whether el is iterable, so collections.abc must be imported from the standard library.

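Strings (str) and byte strings (bytes) are also iterable, so they are excluded. Otherwise, a bytes object would be split into its integer byte values, and a str would cause infinite recursion (RecursionError), because iterating over a one-character string yields that same string.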
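The following sketch shows what happens without the exclusion. flatten_no_str_check() is a hypothetical variant, identical to flatten() except that it does not check for str and bytes.

```python
import collections.abc

def flatten_no_str_check(l):
    # Same as flatten(), but str and bytes are NOT excluded
    for el in l:
        if isinstance(el, collections.abc.Iterable):
            yield from flatten_no_str_check(el)
        else:
            yield el

# bytes are split into their integer byte values
print(list(flatten_no_str_check([[1, 2], b"ab"])))
# [1, 2, 97, 98]

# "ab" -> "a" -> "a" -> ... never terminates
try:
    list(flatten_no_str_check(["ab"]))
except RecursionError:
    print("RecursionError")
# RecursionError
```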

flatten() works for all of the following cases, including the irregular list l_mix.

print(list(flatten(l_2d)))
# [0, 1, 2, 3]

print(list(flatten(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]

print(list(flatten(l_mix)))
# [0, 1, 2, 3, 4]

It also works when the list contains various iterable objects such as lists, tuples, and range objects.

l_t_r_mix = [[0, 1], (2, 3), 4, range(5, 8)]

print(list(flatten(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]
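Because flatten() calls itself once per level of nesting, a list nested roughly as deep as Python's recursion limit (see sys.getrecursionlimit(), 1000 by default in CPython) raises RecursionError. As an alternative, the same logic can be written without recursion by keeping a stack of iterators. This is a minimal sketch under the same rules as flatten() (str and bytes are not flattened); the name flatten_iterative() is hypothetical.

```python
import collections.abc

def flatten_iterative(l):
    """Yield the non-iterable elements of l in order, without recursion."""
    stack = [iter(l)]  # one iterator per nesting level being visited
    while stack:
        for el in stack[-1]:
            if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
                stack.append(iter(el))  # descend into the nested iterable
                break
            yield el
        else:
            stack.pop()  # the innermost iterator is exhausted; go back up

l_t_r_mix = [[0, 1], (2, 3), 4, range(5, 8)]
print(list(flatten_iterative(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]

# 10000 levels of nesting: too deep for the recursive version
deep = [0]
for _ in range(10000):
    deep = [deep]
print(list(flatten_iterative(deep)))
# [0]
```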

If you only need to flatten lists, check for list with isinstance(); no import is needed. Tuples and range objects are then left as they are, which is often sufficient.

def flatten_list(l):
    for el in l:
        if isinstance(el, list):
            yield from flatten_list(el)
        else:
            yield el

print(list(flatten_list(l_2d)))
# [0, 1, 2, 3]

print(list(flatten_list(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]

print(list(flatten_list(l_mix)))
# [0, 1, 2, 3, 4]

print(list(flatten_list(l_t_r_mix)))
# [0, 1, (2, 3), 4, range(5, 8)]

To handle several specific types, pass a tuple of types as the second argument of isinstance().

def flatten_list_tuple_range(l):
    for el in l:
        if isinstance(el, (list, tuple, range)):
            yield from flatten_list_tuple_range(el)
        else:
            yield el

print(list(flatten_list_tuple_range(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]