Pythonでflatten(多次元リストを一次元に平坦化)
Pythonで多次元のリスト(リストのリスト、ネストしたリスト)を一次元に平坦化する方法について説明する。
- 2次元のリストを平坦化
itertools.chain.from_iterable()
sum()
- 処理速度の差
- 3次元以上のリストや不規則なリストを平坦化
NumPy配列ndarray
の場合はflatten()
またはravel()
を使う。
反対に、一次元のNumPy配列ndarray
やリストを二次元に変換する方法については以下の記事を参照。
2次元のリストを平坦化
itertools.chain.from_iterable()
リストを要素として持つリスト(2次元リスト)を平坦化する場合、標準ライブラリのitertoolsのitertools.chain.from_iterable()
を使う方法がある。
import itertools
l_2d = [[0, 1], [2, 3]]
print(list(itertools.chain.from_iterable(l_2d)))
# [0, 1, 2, 3]
itertools.chain.from_iterable()
はイテレータを返すので、リストに変換したい場合は上のサンプルコードのようにlist()
を使う。for文で使う場合はリスト化する必要はない。
タプルも同様に処理できる。ここでは結果をtuple()
でタプルにしている。リストにしたい場合はlist()
を使えばよい。
t_2d = ((0, 1), (2, 3))
print(tuple(itertools.chain.from_iterable(t_2d)))
# (0, 1, 2, 3)
itertools.chain.from_iterable()
で平坦化できるのは2次元の場合のみ。3次元以上の場合(ネストが深い場合)は以下のような結果となる。
l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
print(list(itertools.chain.from_iterable(l_3d)))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
また、要素の中にイテラブルオブジェクトではないものが含まれている場合はエラーとなる。
l_mix = [[0, 1], [2, 3], 4]
# print(list(itertools.chain.from_iterable(l_mix)))
# TypeError: 'int' object is not iterable
3次元以上の場合や要素の型が不規則な場合については後述。
sum()
組み込み関数のsum()
を使う方法もある。
sum()
の第二引数には初期値を指定できる。ここに空のリスト[]
を指定すると、リストの+
演算によって、要素のリストが連結される。
第二引数のデフォルト値は0
なので、省略すると整数int
とリストの+
演算となってしまいエラーとなる。
print(sum(l_2d, []))
# [0, 1, 2, 3]
# print(sum(l_2d))
# TypeError: unsupported operand type(s) for +: 'int' and 'list'
タプルでも同様に処理可能。
print(sum(t_2d, ()))
# (0, 1, 2, 3)
itertools.chain.from_iterable()
と同じように、3次元以上の場合や要素の型が不規則な場合はうまくいかない。
print(sum(l_3d, []))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
# print(sum(l_mix, []))
# TypeError: can only concatenate list (not "int") to list
処理速度の差
sum()
は特に何かをインポートしなくてもよいので手軽だが、行数(内部のリストの数)が多い場合にはitertools.chain.from_iterable()
よりも遅いので注意。行数が多く、かつ、処理速度やメモリ効率が重要な場面ではitertools.chain.from_iterable()
を使うほうがよい。
以下のコードはJupyter Notebook上でマジックコマンド%%timeit
を使って計測したもの。Pythonスクリプトとして実行しても計測されないので注意。
5行の場合。
l_2d_5 = [[0, 1, 2]] * 5
print(l_2d_5)
# [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]
%%timeit
list(itertools.chain.from_iterable(l_2d_5))
# 711 ns ± 21.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%%timeit
sum(l_2d_5, [])
# 448 ns ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
100行。
l_2d_100 = [[0, 1, 2]] * 100
%%timeit
list(itertools.chain.from_iterable(l_2d_100))
# 7.27 µs ± 390 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
sum(l_2d_100, [])
# 41 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
10000行。
l_2d_10000 = [[0, 1, 2]] * 10000
%%timeit
list(itertools.chain.from_iterable(l_2d_10000))
# 513 µs ± 15.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%%timeit
sum(l_2d_10000, [])
# 418 ms ± 22.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3次元以上のリストや不規則なリストを平坦化
3次元以上のリストや不規則なリストを平坦化するには関数を定義する。
以下を参考にした。
import collections
def flatten(l):
for el in l:
if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
yield from flatten(el)
else:
yield el
isinstance()
で要素el
の型をチェックして再帰的に処理している。
collections.abc.Iterable
でイテラブルかどうかを判断。標準ライブラリのcollectionsをインポートする必要がある。
- collections.abc.Iterable --- コレクションの抽象基底クラス — Python 3.7.3 ドキュメント
- 関連記事: Pythonのhasattr(), 抽象基底クラスABCによるダックタイピング
文字列str
やバイト列bytes
もイテラブルであるため除外している。除外しないと文字ごとに分解されてしまう。
この関数を使うと、あらゆる場合に対応できる。
print(list(flatten(l_2d)))
# [0, 1, 2, 3]
print(list(flatten(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]
print(list(flatten(l_mix)))
# [0, 1, 2, 3, 4]
リストやタプル、range
など様々なイテラブルオブジェクトが含まれていても問題ない。
l_t_r_mix = [[0, 1], (2, 3), 4, range(5, 8)]
print(list(flatten(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]
対象をリストに限定すれば、collectionsをインポートしなくてもよい。タプルやrange
はそのままになってしまうが、多くの場合はこれで十分だろう。
def flatten_list(l):
for el in l:
if isinstance(el, list):
yield from flatten_list(el)
else:
yield el
print(list(flatten_list(l_2d)))
# [0, 1, 2, 3]
print(list(flatten_list(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]
print(list(flatten_list(l_mix)))
# [0, 1, 2, 3, 4]
print(list(flatten_list(l_t_r_mix)))
# [0, 1, (2, 3), 4, range(5, 8)]
isinstance()
の第二引数にはタプルで複数の型を指定できるので、必要な型のみ対象としてもよい。
def flatten_list_tuple_range(l):
for el in l:
if isinstance(el, (list, tuple, range)):
yield from flatten_list_tuple_range(el)
else:
yield el
print(list(flatten_list_tuple_range(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]
もちろん、汎用的なのはcollections.abc.Iterable
を使う方法。お好みで。