How to Flatten a List of Lists in Python
In Python, you can flatten a list of lists (nested list, 2D list) using itertools.chain.from_iterable(), sum(), or list comprehensions.
Use ravel() or flatten() to flatten a NumPy array (numpy.ndarray).
If you need to convert a one-dimensional ndarray or list to two dimensions, refer to the following article.
Flatten a list of lists (nested list, 2D list)
Flatten list with itertools.chain.from_iterable()
You can flatten a list of lists with itertools.chain.from_iterable().
import itertools
l_2d = [[0, 1], [2, 3]]
print(list(itertools.chain.from_iterable(l_2d)))
# [0, 1, 2, 3]
itertools.chain.from_iterable() returns an iterator. To convert this iterator to a list, use list(). If you only need to loop over the result with a for loop, you can use the iterator directly without converting it to a list.
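As a minimal sketch of that point, the hypothetical helper sum_flattened() below loops over the chain object directly, so no intermediate list is built. The second part shows a related property worth keeping in mind: an iterator can be consumed only once.

```python
import itertools

def sum_flattened(l_2d):
    """Hypothetical helper: add up every element of a 2D list."""
    total = 0
    # Iterate over the chain object directly; no intermediate list is built.
    for x in itertools.chain.from_iterable(l_2d):
        total += x
    return total

print(sum_flattened([[0, 1], [2, 3]]))
# 6

# An iterator can be consumed only once.
it = itertools.chain.from_iterable([[0, 1], [2, 3]])
print(list(it))
# [0, 1, 2, 3]
print(list(it))
# []
```

If you need to traverse the flattened result more than once, convert it to a list (or call itertools.chain.from_iterable() again).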
This method also works with a tuple of tuples. In the following example, the result is converted to a tuple with tuple(). If you need a list, use list() instead.
t_2d = ((0, 1), (2, 3))
print(tuple(itertools.chain.from_iterable(t_2d)))
# (0, 1, 2, 3)
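More generally, itertools.chain.from_iterable() accepts any iterable whose elements are themselves iterable, so the inner containers do not have to share a type. A quick sketch; note that strings are also iterable, so they are split into individual characters:

```python
import itertools

# Inner containers of different types are chained in order.
mixed = [[0, 1], (2, 3), range(4, 6)]
print(list(itertools.chain.from_iterable(mixed)))
# [0, 1, 2, 3, 4, 5]

# Strings are iterables of characters, so they get split up.
print(list(itertools.chain.from_iterable(["ab", "cd"])))
# ['a', 'b', 'c', 'd']
```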
Note that itertools.chain.from_iterable() removes only one level of nesting, so it fully flattens only 2D lists. For 3D or higher-dimensional lists, the result is as follows.
l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
print(list(itertools.chain.from_iterable(l_3d)))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
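If the depth is fixed and known in advance, and every element at each level is iterable, one option is to apply itertools.chain.from_iterable() once per level. A sketch assuming a uniformly nested 3D list:

```python
import itertools

l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

# Each call removes exactly one level of nesting,
# so two nested calls flatten a uniformly nested 3D list.
once = itertools.chain.from_iterable(l_3d)
twice = itertools.chain.from_iterable(once)
print(list(twice))
# [0, 1, 2, 3, 4, 5, 6, 7]
```

This still fails on irregular nesting or non-iterable elements, which is what the custom function discussed later handles.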
An error is raised if the list contains non-iterable objects such as int or float.
l_mix = [[0, 1], [2, 3], 4]
# print(list(itertools.chain.from_iterable(l_mix)))
# TypeError: 'int' object is not iterable
Details about 3D or higher-dimensional lists and lists with non-iterable elements will be discussed later.
Flatten list with sum()
You can also flatten a list of lists with the built-in sum() function.
You can specify an initial value as the second argument of sum(). By passing an empty list [], the + operation concatenates the lists.
Note that the default value of the second argument is 0. If it is omitted, an error is raised because the + operation between an int and a list is not supported.
l_2d = [[0, 1], [2, 3]]
print(sum(l_2d, []))
# [0, 1, 2, 3]
# print(sum(l_2d))
# TypeError: unsupported operand type(s) for +: 'int' and 'list'
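sum() adds the items from left to right starting from the initial value, so sum(l_2d, []) amounts to ([] + [0, 1]) + [2, 3]. The hypothetical helper sum_concat() below spells this out as a plain loop. Because each + builds a new list and copies everything accumulated so far, it also hints at why sum() slows down as the number of inner lists grows.

```python
def sum_concat(seqs, start):
    """Hypothetical helper mimicking how sum(seqs, start) combines sequences."""
    result = start
    for s in seqs:
        # result + s creates a new object and copies both operands into it.
        result = result + s
    return result

l_2d = [[0, 1], [2, 3]]
print(sum_concat(l_2d, []))
# [0, 1, 2, 3]
print(sum_concat(l_2d, []) == sum(l_2d, []))
# True
```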
Tuples can be handled in the same way.
t_2d = ((0, 1), (2, 3))
print(sum(t_2d, ()))
# (0, 1, 2, 3)
Like itertools.chain.from_iterable(), this method flattens only one level of 3D or higher-dimensional lists and raises an error for lists with non-iterable elements.
l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
print(sum(l_3d, []))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
l_mix = [[0, 1], [2, 3], 4]
# print(sum(l_mix, []))
# TypeError: can only concatenate list (not "int") to list
Flatten list with list comprehensions
You can also flatten a list of lists with a list comprehension that has two for clauses.
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
flat = [x for row in matrix for x in row]
print(flat)
# [1, 2, 3, 4, 5, 6, 7, 8, 9]
The above example is equivalent to the following nested for loop.
flat = []
for row in matrix:
    for x in row:
        flat.append(x)
print(flat)
# [1, 2, 3, 4, 5, 6, 7, 8, 9]
As with the other methods, the list comprehension above flattens only one level, and an error is raised if the list contains non-iterable elements.
You could add more for clauses to handle 3D or higher-dimensional lists, or add conditional branching based on element types, but this quickly becomes complicated and is not recommended.
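For reference, here is a sketch of both variations. It assumes the nesting depth is known in advance and that only list elements should be expanded, which illustrates why this approach gets unwieldy.

```python
l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
# One for clause per level of nesting: block -> row -> element.
print([x for block in l_3d for row in block for x in row])
# [0, 1, 2, 3, 4, 5, 6, 7]

l_mix = [[0, 1], [2, 3], 4]
# Wrap non-list elements in a one-element list so every element is iterable.
print([x for el in l_mix for x in (el if isinstance(el, list) else [el])])
# [0, 1, 2, 3, 4]
```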
See the following article for more information on list comprehensions.
Speed comparison
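While sum() is easy to use, it is significantly slower than itertools.chain.from_iterable() or list comprehensions when there are many inner lists. Each + creates a new list and copies every element accumulated so far, so the total work grows roughly quadratically with the number of inner lists. For performance-sensitive code, it is better to avoid sum().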
Although you have to import itertools, itertools.chain.from_iterable() is faster than list comprehensions.
The following examples use the Jupyter Notebook magic command %%timeit. Note that they will not work if run as Python scripts.
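Outside Jupyter, the standard-library timeit module can take similar measurements. A minimal sketch, assuming 1000 inner lists and number=100 (both arbitrary choices); absolute timings depend on the machine and Python version, so no output is shown here.

```python
import itertools
import timeit

l_2d_1000 = [[0, 1, 2] for _ in range(1000)]

# Each entry is a zero-argument callable; timeit.timeit() calls it
# `number` times and returns the total elapsed time in seconds.
candidates = {
    "chain.from_iterable": lambda: list(itertools.chain.from_iterable(l_2d_1000)),
    "sum": lambda: sum(l_2d_1000, []),
    "list comprehension": lambda: [x for row in l_2d_1000 for x in row],
}

for name, func in candidates.items():
    seconds = timeit.timeit(func, number=100)
    print(f"{name}: {seconds / 100 * 1e6:.1f} µs per loop")
```

Unlike %%timeit, this takes a single measurement per candidate; timeit.repeat() can be used to collect several runs.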
5 inner lists:
l_2d_5 = [[0, 1, 2] for i in range(5)]
print(l_2d_5)
# [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]
%%timeit
list(itertools.chain.from_iterable(l_2d_5))
# 537 ns ± 4.59 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%%timeit
sum(l_2d_5, [])
# 319 ns ± 1.85 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%%timeit
[x for row in l_2d_5 for x in row]
# 764 ns ± 32.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
100 inner lists:
l_2d_100 = [[0, 1, 2] for i in range(100)]
%%timeit
list(itertools.chain.from_iterable(l_2d_100))
# 6.94 µs ± 139 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
sum(l_2d_100, [])
# 35.5 µs ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%%timeit
[x for row in l_2d_100 for x in row]
# 13.5 µs ± 959 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10000 inner lists:
l_2d_10000 = [[0, 1, 2] for i in range(10000)]
%%timeit
list(itertools.chain.from_iterable(l_2d_10000))
# 552 µs ± 79.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%%timeit
sum(l_2d_10000, [])
# 343 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%%timeit
[x for row in l_2d_10000 for x in row]
# 1.11 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Flatten 3D or higher-dimensional lists and lists with non-iterable elements
You need to define a custom function to flatten 3D or higher-dimensional lists or lists with non-iterable elements.
The following sample code is based on a Stack Overflow post.
import collections.abc

def flatten(l):
    for el in l:
        # Recurse into any iterable except str and bytes; yield everything else as-is
        if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el
The function checks the type of each element el using isinstance() and calls itself recursively on nested iterables.
To determine whether el is iterable, the function uses collections.abc.Iterable. This requires importing the collections.abc module.
- collections.abc.Iterable — Abstract Base Classes for Containers — Python 3.8.4 documentation
- Duck typing with hasattr() and abstract base class in Python
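Strings (str) and byte strings (bytes) are iterable, so they are excluded from this operation. Without this exclusion, a bytes object would be split into its individual byte values (integers), and a str would never stop recursing, because each single character is itself an iterable str. In practice this raises a RecursionError.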
This function handles 2D lists, 3D lists, and lists with non-iterable elements alike.
l_2d = [[0, 1], [2, 3]]
print(list(flatten(l_2d)))
# [0, 1, 2, 3]
l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
print(list(flatten(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]
l_mix = [[0, 1], [2, 3], 4]
print(list(flatten(l_mix)))
# [0, 1, 2, 3, 4]
The function can handle a variety of iterable objects, including lists, tuples, and ranges.
l_t_r_mix = [[0, 1], (2, 3), 4, range(5, 8)]
print(list(flatten(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]
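To see why the str/bytes exclusion in flatten() matters, the sketch below uses a hypothetical variant, flatten_no_exclusion(), that omits it. bytes elements are split into integers, and a str never stops recursing, because each one-character string is itself an iterable str.

```python
import collections.abc

def flatten_no_exclusion(l):
    """Hypothetical variant of flatten() WITHOUT the str/bytes check."""
    for el in l:
        if isinstance(el, collections.abc.Iterable):
            yield from flatten_no_exclusion(el)
        else:
            yield el

# bytes iterate as integers, so each byte becomes a separate int.
print(list(flatten_no_exclusion([b"ab", [1]])))
# [97, 98, 1]

# A one-character str iterates to itself, so the recursion never bottoms out.
try:
    list(flatten_no_exclusion(["ab"]))
except RecursionError:
    print("RecursionError")
# RecursionError
```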
If you're only dealing with lists, there's no need to import collections. The function won't flatten tuples or ranges, but this should suffice for most use cases.
def flatten_list(l):
    for el in l:
        # Only lists are expanded; tuples, ranges, etc. are yielded unchanged
        if isinstance(el, list):
            yield from flatten_list(el)
        else:
            yield el
print(list(flatten_list(l_2d)))
# [0, 1, 2, 3]
print(list(flatten_list(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]
print(list(flatten_list(l_mix)))
# [0, 1, 2, 3, 4]
print(list(flatten_list(l_t_r_mix)))
# [0, 1, (2, 3), 4, range(5, 8)]
You can specify multiple types using a tuple as the second argument of isinstance().
def flatten_list_tuple_range(l):
    for el in l:
        # Expand lists, tuples, and ranges; yield everything else as-is
        if isinstance(el, (list, tuple, range)):
            yield from flatten_list_tuple_range(el)
        else:
            yield el
print(list(flatten_list_tuple_range(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]
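Because these functions recurse once per level of nesting, extremely deep nesting can hit Python's recursion limit (see sys.getrecursionlimit(); the default is 1000) and raise a RecursionError. As a sketch of one workaround, the hypothetical flatten_iterative() below replaces recursion with an explicit stack of iterators. Like flatten_list_tuple_range(), it expands only lists, tuples, and ranges by default; the types parameter is an assumption of this sketch.

```python
# Unique marker returned by next() when an iterator is exhausted.
_SENTINEL = object()

def flatten_iterative(l, types=(list, tuple, range)):
    """Hypothetical non-recursive flattener using an explicit stack of iterators."""
    stack = [iter(l)]
    while stack:
        # Pull the next element from the innermost iterator still in progress.
        el = next(stack[-1], _SENTINEL)
        if el is _SENTINEL:
            stack.pop()  # this level is exhausted; resume the enclosing one
        elif isinstance(el, types):
            stack.append(iter(el))  # descend into the nested sequence
        else:
            yield el

l_t_r_mix = [[0, 1], (2, 3), 4, range(5, 8)]
print(list(flatten_iterative(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]

# 5000 levels of nesting exceed the default recursion limit for the recursive
# versions, but the explicit stack is an ordinary list without that limit.
deep = [0]
for i in range(1, 5000):
    deep = [deep, i]
print(list(flatten_iterative(deep)) == list(range(5000)))
# True
```

Elements are still produced in depth-first order, so the output matches the recursive versions for any input they can handle.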