note.nkmk.me

Pythonでflatten(多次元リストを一次元に平坦化)

Date: 2019-06-01 / tags: Python, リスト

Pythonで多次元のリスト(リストのリスト、ネストしたリスト)を一次元に平坦化する方法について説明する。

  • 2次元のリストを平坦化
    • itertools.chain.from_iterable()
    • sum()
  • 3次元以上のリストや不規則なリストを平坦化

NumPy配列ndarrayの場合はflatten()またはravel()を使う。

スポンサーリンク

2次元のリストを平坦化

itertools.chain.from_iterable()

リストを要素として持つ2次元のリストを平坦化する場合、標準ライブラリのitertoolsのitertools.chain.from_iterable()を使う方法がある。

import itertools

l_2d = [[0, 1], [2, 3]]

print(list(itertools.chain.from_iterable(l_2d)))
# [0, 1, 2, 3]

itertools.chain.from_iterable()はイテレータを返すので、リストに変換したい場合は上のサンプルコードのようにlist()を使う。for文で使う場合はリスト化する必要はない。

タプルも同様に処理できる。ここでは結果をtuple()でタプルにしている。リストにしたい場合はlist()を使えばよい。

t_2d = ((0, 1), (2, 3))

print(tuple(itertools.chain.from_iterable(t_2d)))
# (0, 1, 2, 3)

itertools.chain.from_iterable()で平坦化できるのは2次元の場合のみ。3次元以上の場合(ネストが深い場合)は以下のような結果となる。

l_3d = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

print(list(itertools.chain.from_iterable(l_3d)))
# [[0, 1], [2, 3], [4, 5], [6, 7]]

また、要素の中にイテラブルオブジェクトではないものが含まれている場合はエラーとなる。

l_mix = [[0, 1], [2, 3], 4]

# print(list(itertools.chain.from_iterable(l_mix)))
# TypeError: 'int' object is not iterable

3次元以上の場合や要素の型が不規則な場合については後述。

sum()

組み込み関数のsum()を使う方法もある。

sum()の第二引数には初期値を指定できる。ここに空のリスト[]を指定すると、リストの+演算によって、要素のリストが連結される。

第二引数のデフォルト値は0なので、省略すると整数intとリストの+演算となってしまいエラーとなる。

print(sum(l_2d, []))
# [0, 1, 2, 3]

# print(sum(l_2d))
# TypeError: unsupported operand type(s) for +: 'int' and 'list'

タプルでも同様に処理可能。

print(sum(t_2d, ()))
# (0, 1, 2, 3)

itertools.chain.from_iterable()と同じように、3次元以上の場合や要素の型が不規則な場合はうまくいかない。

print(sum(l_3d, []))
# [[0, 1], [2, 3], [4, 5], [6, 7]]

# print(sum(l_mix, []))
# TypeError: can only concatenate list (not "int") to list

sum()は特に何かをインポートしなくてもよいので手軽だが、itertools.chain.from_iterable()よりも遅いので注意。処理速度が重要な場面ではitertools.chain.from_iterable()を使うほうがよい。

3次元以上のリストや不規則なリストを平坦化

3次元以上のリストや不規則なリストを平坦化するには関数を定義する。

以下を参考にした。

import collections

def flatten(l):
    for el in l:
        if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el

isinstance()で要素elの型をチェックして再帰的に処理している。

collections.abc.Iterableでイテラブルかどうかを判断。標準ライブラリのcollectionsをインポートする必要がある。

文字列strやバイト列bytesもイテラブルであるため除外している。除外しないと文字ごとに分解されてしまう。

この関数を使うと、あらゆる場合に対応できる。

print(list(flatten(l_2d)))
# [0, 1, 2, 3]

print(list(flatten(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]

print(list(flatten(l_mix)))
# [0, 1, 2, 3, 4]

リストやタプル、rangeなど様々なイテラブルオブジェクトが含まれていても問題ない。

l_t_r_mix = [[0, 1], (2, 3), 4, range(5, 8)]

print(list(flatten(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]

対象をリストに限定すれば、collectionsをインポートしなくてもよい。タプルやrangeはそのままになってしまうが、多くの場合はこれで十分だろう。

def flatten_list(l):
    for el in l:
        if isinstance(el, list):
            yield from flatten(el)
        else:
            yield el

print(list(flatten_list(l_2d)))
# [0, 1, 2, 3]

print(list(flatten_list(l_3d)))
# [0, 1, 2, 3, 4, 5, 6, 7]

print(list(flatten_list(l_mix)))
# [0, 1, 2, 3, 4]

print(list(flatten_list(l_t_r_mix)))
# [0, 1, (2, 3), 4, range(5, 8)]

isinstance()の第二引数にはタプルで複数の型を指定できるので、必要な型のみ対象としてもよい。

def flatten_list_tuple_range(l):
    for el in l:
        if isinstance(el, (list, tuple, range)):
            yield from flatten(el)
        else:
            yield el

print(list(flatten_list_tuple_range(l_t_r_mix)))
# [0, 1, 2, 3, 4, 5, 6, 7]

もちろん、汎用的なのはcollections.abc.Iterableを使う方法。お好みで。

スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事