note.nkmk.me

How to flatten a list of lists in Python

Posted: 2020-07-21 / Tags: Python, List

To flatten a list of lists (a multi-dimensional list, nested list) in Python, you can use itertools.chain.from_iterable(), sum(), etc.

  • Flatten 2D list (a list of lists)
    • itertools.chain.from_iterable()
    • sum()
    • Speed
  • Flatten 3D or higher lists and irregular lists

Flatten 2D list (a list of lists)

itertools.chain.from_iterable()

A list of lists (2D list) can be flattened using itertools.chain.from_iterable().

import itertools

l_2d = [[0, 1], [2, 3]]

print(list(itertools.chain.from_iterable(l_2d)))
# [0, 1, 2, 3]

itertools.chain.from_iterable() returns an iterator, so use list() to convert the result to a list. You don't need to make a list if you only iterate over the result in a for statement.
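
As a small illustration of the point about for statements, the sketch below iterates over the result directly without building a list. The variable name x is only for illustration.

```python
import itertools

l_2d = [[0, 1], [2, 3]]

# The iterator yields the elements one at a time; no intermediate list is built.
for x in itertools.chain.from_iterable(l_2d):
    print(x)
# 0
# 1
# 2
# 3
```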

A tuple of tuples can be handled in the same way. Here, the result is converted to a tuple by using tuple(). If you need a list, use list().

t_2d = ((0, 1), (2, 3))

print(tuple(itertools.chain.from_iterable(t_2d)))
# (0, 1, 2, 3)

itertools.chain.from_iterable() flattens only one level of nesting. For a list with three or more dimensions, the result is as follows.

l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

print(list(itertools.chain.from_iterable(l_3d)))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
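
Because each call removes exactly one level, a list whose depth is known in advance could be flattened by chaining the calls. This is only a sketch for a regular 3D list; the names once, twice, and flat_3d are chosen for illustration.

```python
import itertools

l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

# The first call removes the outermost level, the second call the next one.
once = itertools.chain.from_iterable(l_3d)
twice = itertools.chain.from_iterable(once)

flat_3d = list(twice)
print(flat_3d)
# [0, 1, 2, 3, 4, 5, 6, 7]
```

This only works when every element at each level is iterable and the depth is fixed.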

A TypeError is raised if the outer list contains an element that is not iterable, such as an int or float.

l_mix = [[0, 1], [2, 3], 4]

# print(list(itertools.chain.from_iterable(l_mix)))
# TypeError: 'int' object is not iterable

3D or higher lists and irregular lists are covered later.

sum()

You can also use the built-in function sum() to flatten.

An initial value can be specified as the second argument of sum(). If you pass the empty list [], the list's + operation will concatenate lists.

Note that the default value of the second argument is 0, so if omitted, an error will occur due to + operation with int and list.

print(sum(l_2d, []))
# [0, 1, 2, 3]

# print(sum(l_2d))
# TypeError: unsupported operand type(s) for +: 'int' and 'list'
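
To make the role of the initial value concrete, the following sketch spells out roughly what sum(l_2d, []) does: it starts from [] and applies + to each inner list in turn. The loop and the names result and row are illustrative, not how sum() is implemented internally.

```python
l_2d = [[0, 1], [2, 3]]

# Start from the initial value [] and add each inner list in turn.
result = []
for row in l_2d:
    result = result + row  # list + list builds a new, concatenated list

print(result)
# [0, 1, 2, 3]

print(result == sum(l_2d, []))
# True
```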

Tuples can be handled in the same way.

print(sum(t_2d, ()))
# (0, 1, 2, 3)

Like itertools.chain.from_iterable(), sum() flattens only one level, so it does not fully flatten 3D or higher lists, and it raises an error for irregular lists.

print(sum(l_3d, []))
# [[0, 1], [2, 3], [4, 5], [6, 7]]

# print(sum(l_mix, []))
# TypeError: can only concatenate list (not "int") to list

Speed

sum() is convenient because you don't have to import anything, but it becomes much slower than itertools.chain.from_iterable() as the number of lines (the number of inner lists) grows, because every + operation builds a new list.

Use itertools.chain.from_iterable() when the number of lines is large and processing speed or memory efficiency matters.

The following results were measured with the magic command %%timeit in Jupyter Notebook. Note that %%timeit doesn't work in a regular Python script.

5 lines:

l_2d_5 = [[0, 1, 2]] * 5
print(l_2d_5)
# [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]

%%timeit
list(itertools.chain.from_iterable(l_2d_5))
# 711 ns ± 21.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

%%timeit
sum(l_2d_5, [])
# 448 ns ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

100 lines:

l_2d_100 = [[0, 1, 2]] * 100

%%timeit
list(itertools.chain.from_iterable(l_2d_100))
# 7.27 µs ± 390 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%%timeit
sum(l_2d_100, [])
# 41 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

10000 lines:

l_2d_10000 = [[0, 1, 2]] * 10000

%%timeit
list(itertools.chain.from_iterable(l_2d_10000))
# 513 µs ± 15.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
sum(l_2d_10000, [])
# 418 ms ± 22.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
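
Since %%timeit is only available in Jupyter Notebook, a similar comparison in a plain Python script could use the standard-library timeit module. The following is a rough sketch; the list size of 1000, the repeat count n, and the helper names flatten_chain and flatten_sum are assumptions for illustration, and the timings will differ by environment.

```python
import itertools
import timeit

l_2d_1000 = [[0, 1, 2]] * 1000  # hypothetical size for this sketch

def flatten_chain(l):
    return list(itertools.chain.from_iterable(l))

def flatten_sum(l):
    return sum(l, [])

n = 20  # number of calls per measurement
t_chain = timeit.timeit(lambda: flatten_chain(l_2d_1000), number=n)
t_sum = timeit.timeit(lambda: flatten_sum(l_2d_1000), number=n)

# timeit.timeit() returns the total time in seconds for n calls.
print(f'chain.from_iterable: {t_chain / n * 1e6:.1f} us per call')
print(f'sum:                 {t_sum / n * 1e6:.1f} us per call')
```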

Flatten 3D or higher lists and irregular lists

To flatten 3D or higher lists or irregular lists, you need to define your own function.

I referred to the following link.

import collections.abc

def flatten(l):
    for el in l:
        if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el

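The type of each element el is checked with isinstance(), and iterable elements are processed recursively.

Whether el is iterable is determined with collections.abc.Iterable. It is defined in the standard-library module collections.abc, so import collections.abc explicitly rather than relying on import collections alone.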

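Strings (str) and byte strings (bytes) are also iterable, so they are excluded. If str were not excluded, each string would be split into single characters, and because a one-character string is itself iterable, the recursion would never end and a RecursionError would be raised. A bytes object would be split into integers.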
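
The effect of excluding str can be checked with a short sketch. Here flatten_naive is a hypothetical variant without the exclusion, added only for comparison, and l_str is an illustrative input.

```python
import collections.abc

def flatten(l):
    for el in l:
        if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el

def flatten_naive(l):
    # Hypothetical variant: str and bytes are NOT excluded.
    for el in l:
        if isinstance(el, collections.abc.Iterable):
            yield from flatten_naive(el)
        else:
            yield el

l_str = [['abc', 'de'], ['f', 1]]

print(list(flatten(l_str)))
# ['abc', 'de', 'f', 1]

try:
    list(flatten_naive(l_str))
except RecursionError:
    # A one-character string such as 'a' yields 'a' again, so the recursion never ends.
    print('RecursionError')
# RecursionError
```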

This function handles all of the examples above.

print(list(flatten(l_2d)))
# [0, 1, 2, 3]

print(list(flatten(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]

print(list(flatten(l_mix)))
# [0, 1, 2, 3, 4]

Different iterable types such as lists, tuples, and range objects can be mixed.

l_t_r_mix = [[0, 1], (2, 3), 4, range(5, 8)]

print(list(flatten(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]

If you only need to handle lists, you don't need to import collections. Tuples and range objects are then left unflattened, but in most cases this is sufficient.

def flatten_list(l):
    for el in l:
        if isinstance(el, list):
            yield from flatten_list(el)
        else:
            yield el

print(list(flatten_list(l_2d)))
# [0, 1, 2, 3]

print(list(flatten_list(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]

print(list(flatten_list(l_mix)))
# [0, 1, 2, 3, 4]

print(list(flatten_list(l_t_r_mix)))
# [0, 1, (2, 3), 4, range(5, 8)]

You can specify multiple types by passing a tuple as the second argument of isinstance().

def flatten_list_tuple_range(l):
    for el in l:
        if isinstance(el, (list, tuple, range)):
            yield from flatten_list_tuple_range(el)
        else:
            yield el

print(list(flatten_list_tuple_range(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]