# How to flatten a list of lists in Python

Posted: 2020-07-21 / Tags: Python, List

To flatten a list of lists (a multi-dimensional or nested list) in Python, you can use `itertools.chain.from_iterable()` or `sum()`. For 3D or deeper lists and irregular lists, you can write a custom recursive function.

- Flatten 2D list (a list of lists)
    - `itertools.chain.from_iterable()`
    - `sum()`
    - Speed
- Flatten 3D or higher lists and irregular lists

## Flatten 2D list (a list of lists)

### itertools.chain.from_iterable()

A list of lists (2D list) can be flattened using `itertools.chain.from_iterable()`.

``````python
import itertools

l_2d = [[0, 1], [2, 3]]

print(list(itertools.chain.from_iterable(l_2d)))
# [0, 1, 2, 3]
``````

`itertools.chain.from_iterable()` returns an iterator, so use `list()` if you need a list. There is no need to convert it when you only iterate over it, for example in a `for` loop.
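For example, a `for` loop or a function that accepts any iterable, such as `sum()`, can consume the iterator directly. Keep in mind that, like any iterator, it can be consumed only once:

```python
import itertools

l_2d = [[0, 1], [2, 3]]

# A for loop consumes the iterator directly; no flattened list is built in memory
for x in itertools.chain.from_iterable(l_2d):
    print(x)
# 0
# 1
# 2
# 3

# Functions that accept any iterable, such as sum() or max(), can take it as is
print(sum(itertools.chain.from_iterable(l_2d)))
# 6

# An iterator can be consumed only once
it = itertools.chain.from_iterable(l_2d)
print(list(it))
# [0, 1, 2, 3]
print(list(it))
# []
```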

A tuple of tuples can be handled in the same way. Here, the result is converted to a tuple by using `tuple()`. If you need a list, use `list()`.

``````python
t_2d = ((0, 1), (2, 3))

print(tuple(itertools.chain.from_iterable(t_2d)))
# (0, 1, 2, 3)
``````

`itertools.chain.from_iterable()` removes only one level of nesting, so it fully flattens only 2D lists. For a 3D or higher list, the result is still nested:

``````python
l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

print(list(itertools.chain.from_iterable(l_3d)))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
``````
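If every element has the same, known nesting depth, one option is to apply `itertools.chain.from_iterable()` once per level. The helper `flatten_depth()` below is a hypothetical name used for illustration; it assumes every element down to the requested depth is iterable, so it does not handle irregular lists:

```python
import itertools

def flatten_depth(l, depth):
    """Remove `depth` levels of nesting (hypothetical helper, for illustration).

    Assumes every element at each of those levels is iterable.
    """
    it = l
    for _ in range(depth):
        # Each pass wraps the previous iterator and removes exactly one level
        it = itertools.chain.from_iterable(it)
    return list(it)

l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

print(flatten_depth(l_3d, 1))
# [[0, 1], [2, 3], [4, 5], [6, 7]]

print(flatten_depth(l_3d, 2))
# [0, 1, 2, 3, 4, 5, 6, 7]
```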

A `TypeError` is raised if the list contains an element that is not iterable, such as an `int` or a `float`.

``````python
l_mix = [[0, 1], [2, 3], 4]

# print(list(itertools.chain.from_iterable(l_mix)))
# TypeError: 'int' object is not iterable
``````
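If only some elements are scalars, one workaround is to wrap each non-list element in a one-element list before chaining. This is a sketch that assumes only `list` elements should be expanded; it still removes just one level of nesting:

```python
import itertools

l_mix = [[0, 1], [2, 3], 4]

# Wrap each non-list element in a one-element list so that every element is iterable
wrapped = (el if isinstance(el, list) else [el] for el in l_mix)

print(list(itertools.chain.from_iterable(wrapped)))
# [0, 1, 2, 3, 4]
```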

Flattening 3D or higher lists and irregular lists is described later.

### sum()

You can also use the built-in function `sum()` to flatten.

`sum()` accepts a start value as its second argument. If you pass an empty list `[]`, the inner lists are concatenated onto it with the list `+` operator.

Note that the second argument defaults to `0`, so omitting it raises a `TypeError`, because `+` is not supported between an `int` and a `list`.

``````python
print(sum(l_2d, []))
# [0, 1, 2, 3]

# print(sum(l_2d))
# TypeError: unsupported operand type(s) for +: 'int' and 'list'
``````
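Conceptually, `sum(l_2d, [])` starts from the start value and applies `+` with each inner list in turn. The loop below is a simplified model of that process for illustration, not the actual implementation of `sum()`:

```python
l_2d = [[0, 1], [2, 3]]

result = []  # corresponds to the start value passed as the second argument
for row in l_2d:
    # list + list builds a new list containing the elements of both operands
    result = result + row

print(result)
# [0, 1, 2, 3]

print(result == sum(l_2d, []))
# True
```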

Tuples can be handled in the same way.

``````python
print(sum(t_2d, ()))
# (0, 1, 2, 3)
``````

Like `itertools.chain.from_iterable()`, `sum()` removes only one level of nesting and raises an error for irregular lists.

``````python
print(sum(l_3d, []))
# [[0, 1], [2, 3], [4, 5], [6, 7]]

# print(sum(l_mix, []))
# TypeError: can only concatenate list (not "int") to list
``````

### Speed

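`sum()` is convenient because it requires no import, but it becomes much slower than `itertools.chain.from_iterable()` as the number of inner lists grows. Each `+` creates a new list and copies every element accumulated so far, so the total work grows roughly quadratically with the number of inner lists, while `itertools.chain.from_iterable()` visits each element only once. With only a few inner lists, `sum()` can actually be faster, as the 5-list measurement below shows.

When there are many inner lists and processing speed or memory efficiency matters, use `itertools.chain.from_iterable()`.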

The following measurements were made with the `%%timeit` magic command in Jupyter Notebook. Note that `%%timeit` does not work in a regular Python script.

5 inner lists:

``````python
l_2d_5 = [[0, 1, 2]] * 5
print(l_2d_5)
# [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]

%%timeit
list(itertools.chain.from_iterable(l_2d_5))
# 711 ns ± 21.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

%%timeit
sum(l_2d_5, [])
# 448 ns ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
``````

100 inner lists:

``````python
l_2d_100 = [[0, 1, 2]] * 100

%%timeit
list(itertools.chain.from_iterable(l_2d_100))
# 7.27 µs ± 390 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%%timeit
sum(l_2d_100, [])
# 41 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
``````

10000 inner lists:

``````python
l_2d_10000 = [[0, 1, 2]] * 10000

%%timeit
list(itertools.chain.from_iterable(l_2d_10000))
# 513 µs ± 15.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
sum(l_2d_10000, [])
# 418 ms ± 22.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
``````
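Because `%%timeit` is available only in IPython and Jupyter, a similar comparison in a regular Python script can be made with the standard `timeit` module. This is a minimal sketch; the list size (1000 inner lists) and the repetition count `number = 10` are arbitrary choices, and the timings depend on the machine:

```python
import itertools
import timeit

l_2d_1000 = [[0, 1, 2]] * 1000
number = 10

# timeit.timeit() calls the function `number` times and returns the total time in seconds
t_chain = timeit.timeit(lambda: list(itertools.chain.from_iterable(l_2d_1000)), number=number)
t_sum = timeit.timeit(lambda: sum(l_2d_1000, []), number=number)

# Convert the total to microseconds per call
print(f'itertools.chain.from_iterable(): {t_chain / number * 1e6:.1f} µs per loop')
print(f'sum():                           {t_sum / number * 1e6:.1f} µs per loop')
```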

## Flatten 3D or higher lists and irregular lists

To flatten 3D or higher lists or irregular lists, you need to define your own function.

I referred to the following link.

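``````python
import collections.abc

def flatten(l):
    """Recursively yield the elements of an arbitrarily nested iterable."""
    for el in l:
        # Recurse into iterables, but treat str and bytes as single values
        if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el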
``````

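Each element `el` is checked with `isinstance()`. If it is iterable, `flatten()` calls itself on it recursively and passes on the results with `yield from`; otherwise, `el` itself is yielded.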

`collections.abc.Iterable` is used to check whether `el` is iterable, which requires importing `collections.abc` from the standard library.

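Strings (`str`) and byte strings (`bytes`) are also iterable, so they are explicitly excluded. Without this check, a `bytes` object would be split into its integer byte values, and a string would cause infinite recursion ending in a `RecursionError`, because iterating over a one-character string yields that same one-character string again.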
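To check this behavior, here is `flatten_naive()`, a hypothetical copy of `flatten()` with the `str`/`bytes` check removed, applied to a list containing `bytes` and to a list containing a `str`:

```python
import collections.abc

def flatten_naive(l):
    # Same as flatten() but WITHOUT excluding str and bytes (for illustration only)
    for el in l:
        if isinstance(el, collections.abc.Iterable):
            yield from flatten_naive(el)
        else:
            yield el

# Iterating over bytes yields ints, so a bytes object is split into byte values
print(list(flatten_naive([b'ab', [1, 2]])))
# [97, 98, 1, 2]

# Iterating over 'a' yields 'a' again, so the recursion never terminates
try:
    list(flatten_naive(['ab', [1, 2]]))
    outcome = 'finished'
except RecursionError:
    outcome = 'RecursionError'

print(outcome)
# RecursionError
```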

`flatten()` works for all of the cases shown earlier: 2D, 3D, and irregular lists.

``````python
print(list(flatten(l_2d)))
# [0, 1, 2, 3]

print(list(flatten(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]

print(list(flatten(l_mix)))
# [0, 1, 2, 3, 4]
``````

`flatten()` also works when the list mixes different kinds of iterables, such as lists, tuples, and `range` objects.

``````python
l_t_r_mix = [[0, 1], (2, 3), 4, range(5, 8)]

print(list(flatten(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]
``````
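Because `flatten()` calls itself once per nesting level, a list nested more deeply than the interpreter's recursion limit (see `sys.getrecursionlimit()`, 1000 by default in CPython) raises a `RecursionError`. This is rarely a problem in practice, but if it is, one alternative is an iterative version that keeps its own stack of iterators. The following is a sketch under the hypothetical name `flatten_iter()`, using the same iterable check as `flatten()`:

```python
import collections.abc

def flatten_iter(l):
    # An explicit stack of iterators replaces the call stack used by recursion
    stack = [iter(l)]
    while stack:
        for el in stack[-1]:
            if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
                # Descend into the nested iterable; the outer iterator resumes later
                stack.append(iter(el))
                break
            yield el
        else:
            # The innermost iterator is exhausted; return to the enclosing one
            stack.pop()

print(list(flatten_iter([[0, 1], (2, 3), 4, range(5, 8), 'abc'])))
# [0, 1, 2, 3, 4, 5, 6, 7, 'abc']

# 3000 levels of nesting: [[[[0], 1], 2], ...]
l_deep = [0]
for i in range(1, 3000):
    l_deep = [l_deep, i]

print(list(flatten_iter(l_deep))[-3:])
# [2997, 2998, 2999]
```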

If you only need to flatten nested lists, you can check for `list` instead, which avoids importing `collections.abc`. Tuples and `range` objects are then left as they are, but this is often sufficient.

``````python
def flatten_list(l):
    for el in l:
        if isinstance(el, list):
            yield from flatten_list(el)
        else:
            yield el

print(list(flatten_list(l_2d)))
# [0, 1, 2, 3]

print(list(flatten_list(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]

print(list(flatten_list(l_mix)))
# [0, 1, 2, 3, 4]

print(list(flatten_list(l_t_r_mix)))
# [0, 1, (2, 3), 4, range(5, 8)]
``````

To flatten several specific types, pass a tuple of types as the second argument of `isinstance()`.

``````python
def flatten_list_tuple_range(l):
    for el in l:
        if isinstance(el, (list, tuple, range)):
            yield from flatten_list_tuple_range(el)
        else:
            yield el

print(list(flatten_list_tuple_range(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]
``````