# How to flatten a list of lists in Python

Posted: 2020-07-21 / Modified: 2021-09-01 / Tags: Python, List

This article describes how to flatten a list of lists (a multidimensional or nested list) in Python. You can use `itertools.chain.from_iterable()`, `sum()`, or list comprehensions.

• Flatten 2D list (a list of lists)
• Flatten list with `itertools.chain.from_iterable()`
• Flatten list with `sum()`
• Flatten list with list comprehensions
• Speed comparison
• Flatten 3D and more multidimensional lists and irregular lists

Use `ravel()` or `flatten()` to flatten a NumPy array `ndarray`.

Conversely, to convert a one-dimensional `ndarray` or `list` to two dimensions, see the following article.

## Flatten 2D list (a list of lists)

### Flatten list with `itertools.chain.from_iterable()`

A list of lists (2D list) can be flattened by `itertools.chain.from_iterable()`.

``````python
import itertools

l_2d = [[0, 1], [2, 3]]

print(list(itertools.chain.from_iterable(l_2d)))
# [0, 1, 2, 3]
``````

`itertools.chain.from_iterable()` returns an iterator, so use `list()` if you need a list. If you only iterate over the result in a `for` loop, you do not need to convert it to a list.
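As a minimal sketch, the iterator can be looped over directly without building a list. Like any iterator, it can be consumed only once; the variable names `total`, `it`, `first`, and `second` are illustrative.

``````python
import itertools

l_2d = [[0, 1], [2, 3]]

# Loop directly over the iterator; no intermediate list is built.
total = 0
for x in itertools.chain.from_iterable(l_2d):
    total += x
print(total)
# 6

# An iterator can be consumed only once: the second list() is empty.
it = itertools.chain.from_iterable(l_2d)
first = list(it)
second = list(it)
print(first, second)
# [0, 1, 2, 3] []
``````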

A tuple of tuples can be handled in the same way. In the following example, the result is converted to a tuple by using `tuple()`. If you need a list, use `list()`.

``````python
t_2d = ((0, 1), (2, 3))

print(tuple(itertools.chain.from_iterable(t_2d)))
# (0, 1, 2, 3)
``````

`itertools.chain.from_iterable()` flattens only one level, so only 2D lists are flattened completely. For a 3D or higher-dimensional list, the result is as follows.

``````python
l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

print(list(itertools.chain.from_iterable(l_3d)))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
``````
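Since each call removes exactly one level of nesting, applying `itertools.chain.from_iterable()` twice flattens a regular 3D list, provided every element at the intermediate level is itself iterable. A minimal sketch (variable names are illustrative):

``````python
import itertools

l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

# First call: 3D -> 2D (an iterator over the inner lists).
once = itertools.chain.from_iterable(l_3d)
# Second call: 2D -> 1D.
twice = itertools.chain.from_iterable(once)

result = list(twice)
print(result)
# [0, 1, 2, 3, 4, 5, 6, 7]
``````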

An error occurs if the list contains non-iterable objects such as `int` or `float`.

``````python
l_mix = [[0, 1], [2, 3], 4]

# print(list(itertools.chain.from_iterable(l_mix)))
# TypeError: 'int' object is not iterable
``````

Flattening 3D or higher-dimensional lists and irregular lists is described later.

### Flatten list with `sum()`

You can also use the built-in function `sum()` to flatten a list.

An initial value can be specified as the second argument of `sum()`. If you pass an empty list `[]`, the `+` operator of `list` concatenates the inner lists.

Note that the default value of the second argument is `0`, so if you omit it, a `TypeError` is raised because `int` and `list` cannot be added with `+`.

``````python
print(sum(l_2d, []))
# [0, 1, 2, 3]

# print(sum(l_2d))
# TypeError: unsupported operand type(s) for +: 'int' and 'list'
``````
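To make the sequence of `+` operations visible, `sum(l_2d, [])` can be compared with an explicit left-to-right fold using `functools.reduce()` and `operator.add`. This is an illustration of the equivalent additions, `([] + [0, 1]) + [2, 3]`, not a claim about how `sum()` is implemented internally.

``````python
import functools
import operator

l_2d = [[0, 1], [2, 3]]

# Start from [] and add each inner list in order:
# ([] + [0, 1]) + [2, 3]
reduced = functools.reduce(operator.add, l_2d, [])
print(reduced)
# [0, 1, 2, 3]

print(reduced == sum(l_2d, []))
# True
``````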

Tuples can be handled in the same way.

``````python
print(sum(t_2d, ()))
# (0, 1, 2, 3)
``````

Like `itertools.chain.from_iterable()`, `sum()` flattens only one level, so it does not work for 3D or higher-dimensional lists or for irregular lists.

``````python
print(sum(l_3d, []))
# [[0, 1], [2, 3], [4, 5], [6, 7]]

# print(sum(l_mix, []))
# TypeError: can only concatenate list (not "int") to list
``````

### Flatten list with list comprehensions

You can also use a list comprehension with nested `for` clauses.

``````python
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

flat = [x for row in matrix for x in row]
print(flat)
# [1, 2, 3, 4, 5, 6, 7, 8, 9]
``````

The example above is equivalent to the following nested `for` loop.

``````python
flat = []
for row in matrix:
    for x in row:
        flat.append(x)

print(flat)
# [1, 2, 3, 4, 5, 6, 7, 8, 9]
``````

As with the other methods, the list comprehension above flattens only one level, and an error occurs if the list contains non-iterable objects.

You can add more `for` clauses to handle three or more dimensions, or branch on the type of each element, but this is not recommended because the comprehension quickly becomes too complicated.
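For a regular 3D list, one extra `for` clause per level is still readable. A minimal sketch (the names `plane` and `row` are illustrative); the clauses read left to right, like the equivalent nested `for` loops:

``````python
l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

# Outer clause first: each plane, then each row in that plane,
# then each element in that row.
flat = [x for plane in l_3d for row in plane for x in row]
print(flat)
# [0, 1, 2, 3, 4, 5, 6, 7]
``````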

### Speed comparison

Note that although `sum()` is easy to use, it is much slower than `itertools.chain.from_iterable()` or list comprehensions when the number of lines (the number of inner lists) is large, because each `+` creates a new list and copies all the elements accumulated so far. Avoid `sum()` when the number of lines is large and processing speed and memory efficiency matter.
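Because `sum()` builds a new list on every `+`, a plain loop that appends into a single list with `list.extend()` avoids that repeated copying. This is a swapped-in alternative that is not part of the measurements below; the function name `flatten_extend` is hypothetical.

``````python
def flatten_extend(l_2d):
    """Hypothetical helper: flatten one level into a single result list."""
    result = []
    for row in l_2d:
        # extend() appends the items of row to the existing list in place,
        # so no new list is created for each inner list (unlike sum()).
        result.extend(row)
    return result


print(flatten_extend([[0, 1], [2, 3]]))
# [0, 1, 2, 3]
``````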

Although it requires importing `itertools`, `itertools.chain.from_iterable()` is faster than the list comprehension in all three measurements below.

The following code was measured with the magic command `%%timeit` in Jupyter Notebook. Note that `%%timeit` does not work in a regular Python script.
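In a regular Python script, the standard `timeit` module can be used instead. A minimal sketch, assuming the 100-line case and 1000 calls per measurement (both arbitrary choices); the numbers it prints depend on the machine and Python version and will not match the notebook results exactly, and the `lambda` wrappers add a small call overhead.

``````python
import itertools
import timeit

l_2d_100 = [[0, 1, 2] for i in range(100)]

n = 1000  # number of calls per measurement

# timeit.timeit() accepts a callable and returns the total time in seconds
# for `number` calls.
t_chain = timeit.timeit(lambda: list(itertools.chain.from_iterable(l_2d_100)), number=n)
t_sum = timeit.timeit(lambda: sum(l_2d_100, []), number=n)
t_comp = timeit.timeit(lambda: [x for row in l_2d_100 for x in row], number=n)

# Average time per call, in microseconds.
for name, t in [('chain.from_iterable', t_chain), ('sum', t_sum), ('comprehension', t_comp)]:
    print(f'{name}: {t / n * 1e6:.2f} microseconds per call')
``````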

5 lines:

``````python
l_2d_5 = [[0, 1, 2] for i in range(5)]
print(l_2d_5)
# [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]

%%timeit
list(itertools.chain.from_iterable(l_2d_5))
# 537 ns ± 4.59 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

%%timeit
sum(l_2d_5, [])
# 319 ns ± 1.85 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

%%timeit
[x for row in l_2d_5 for x in row]
# 764 ns ± 32.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
``````

100 lines:

``````python
l_2d_100 = [[0, 1, 2] for i in range(100)]

%%timeit
list(itertools.chain.from_iterable(l_2d_100))
# 6.94 µs ± 139 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%%timeit
sum(l_2d_100, [])
# 35.5 µs ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit
[x for row in l_2d_100 for x in row]
# 13.5 µs ± 959 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
``````

10000 lines:

``````python
l_2d_10000 = [[0, 1, 2] for i in range(10000)]

%%timeit
list(itertools.chain.from_iterable(l_2d_10000))
# 552 µs ± 79.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
sum(l_2d_10000, [])
# 343 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
[x for row in l_2d_10000 for x in row]
# 1.11 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
``````

## Flatten 3D and more multidimensional lists and irregular lists

To flatten 3D or higher-dimensional lists or irregular lists, you need to define your own function.

The following sample code is based on the following article.

``````python
import collections.abc

def flatten(l):
    for el in l:
        # Recurse into any iterable except str and bytes.
        if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el
``````

The type of element `el` is checked by `isinstance()` and processed recursively.

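`collections.abc.Iterable` is used to determine whether `el` is iterable. This requires importing `collections.abc` from the standard library.

The string `str` and the byte string `bytes` are also iterable, so they are excluded. Without this exclusion, a string would be recursed into character by character, and because each one-character string is itself an iterable `str`, the recursion would never end and a `RecursionError` would be raised. A `bytes` object would be split into integers.

This function works for all of the lists above.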
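To check the effect of the exclusion, the function is repeated below so the snippet runs on its own and applied to a list containing strings and bytes (`l_str` is an illustrative name). Both are kept as single elements instead of being split up.

``````python
import collections.abc

def flatten(l):
    for el in l:
        # Recurse into iterables, but treat str and bytes as atomic values.
        if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el


l_str = ['ab', ['cd', [b'ef']], 0]
print(list(flatten(l_str)))
# ['ab', 'cd', b'ef', 0]
``````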

``````python
print(list(flatten(l_2d)))
# [0, 1, 2, 3]

print(list(flatten(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]

print(list(flatten(l_mix)))
# [0, 1, 2, 3, 4]
``````

The function also works when the list contains a mix of iterable objects such as lists, tuples, and `range`.

``````python
l_t_r_mix = [[0, 1], (2, 3), 4, range(5, 8)]

print(list(flatten(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]
``````

If you only need to handle lists, you can check for `list` directly and do not need `collections.abc`. Tuples and `range` objects are then not flattened, but in most cases this is sufficient.

``````python
def flatten_list(l):
    for el in l:
        if isinstance(el, list):
            yield from flatten_list(el)
        else:
            yield el

print(list(flatten_list(l_2d)))
# [0, 1, 2, 3]

print(list(flatten_list(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]

print(list(flatten_list(l_mix)))
# [0, 1, 2, 3, 4]

print(list(flatten_list(l_t_r_mix)))
# [0, 1, (2, 3), 4, range(5, 8)]
``````

You can specify multiple types as a tuple in the second argument of `isinstance()`.

``````python
def flatten_list_tuple_range(l):
    for el in l:
        if isinstance(el, (list, tuple, range)):
            yield from flatten_list_tuple_range(el)
        else:
            yield el

print(list(flatten_list_tuple_range(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]
``````
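The recursive functions above are limited by Python's recursion limit (usually 1000 by default, see `sys.getrecursionlimit()`), so an extremely deeply nested list raises `RecursionError`. As a swapped-in alternative not taken from the source, the same type check can be combined with an explicit stack of iterators instead of recursion. A minimal sketch; the function name `flatten_iterative` is hypothetical.

``````python
def flatten_iterative(l):
    """Hypothetical helper: flatten lists, tuples, and range objects
    using an explicit stack instead of recursion."""
    # Stack of iterators; the top one is the nesting level currently being read.
    stack = [iter(l)]
    while stack:
        for el in stack[-1]:
            if isinstance(el, (list, tuple, range)):
                # Descend into the nested object; this level resumes later
                # from where its iterator stopped.
                stack.append(iter(el))
                break
            yield el
        else:
            # The current level is exhausted: go back up one level.
            stack.pop()


print(list(flatten_iterative([[0, 1], (2, 3), 4, range(5, 8)])))
# [0, 1, 2, 3, 4, 5, 6, 7]

# A list nested 5000 levels deep: [[[...[0], 1], ...], 4999]
deep = [0]
for i in range(1, 5000):
    deep = [deep, i]
print(list(flatten_iterative(deep)) == list(range(5000)))
# True
``````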