PythonでのUnion-Find(素集合データ構造)の実装と使い方

Modified: | Tags: Python, アルゴリズム

PythonでのUnion Findデータ構造(素集合データ構造)の実装とその使い方を説明する。

まずはじめに最終形のクラスとその使い方のサンプルコードを示し、後半で素朴な実装からいくつかの工夫を加えて最終形に至るまでを説明する。

  • Union Find(素集合データ構造)の概要
  • PythonでのUnion Findの実装例
  • サンプルコードのクラスの使い方
  • 文字列やタプルなどを要素とする場合
  • 素朴な実装
  • 実装の効率化
    • 経路圧縮(Path Compression)
    • ランク(Union by Rank)
    • サイズ(Union by Size)
    • 根の値にサイズ、ランクを格納

AtCoderおよびAOJのUnion Findに関する問題の解答コードを参考にした。

Union Find(素集合データ構造)の概要

要素を素集合(互いに重ならない集合)に分割して管理するデータ構造を素集合データ構造(またはUnion-Findデータ構造、Merge-Find集合)と呼ぶ。このデータ構造は以下の2つの操作(Union-Findアルゴリズム)をサポートする。集合の分割はできない。

  • Union: 2つの集合を1つに併合する
  • Find: ある要素がどの集合に属しているかを判定する
    • 2つの要素が同じ集合に属しているかの判定も可能

データ構造もその操作もUnion Findと呼ぶ(らしい)。

概要を掴むのには以下のスライドの説明が分かりやすかった。

Union-Find木と呼ばれることもあるようだが、各集合(グループ)が木で表されるので、複数の集合がある場合は全体では森となる。

PythonでのUnion Findの実装例

Union, Findの操作をサポートするUnion FIndデータ構造のPythonでの実装例は以下の通り。あくまでも例であり、より効率的な実装があるかもしれない。

from collections import defaultdict


class UnionFind():
    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        if self.parents[x] < 0:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.parents[x] > self.parents[y]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def size(self, x):
        return -self.parents[self.find(x)]

    def same(self, x, y):
        return self.find(x) == self.find(y)

    def members(self, x):
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        return [i for i, x in enumerate(self.parents) if x < 0]

    def group_count(self):
        return len(self.roots())

    def all_group_members(self):
        group_members = defaultdict(list)
        for member in range(self.n):
            group_members[self.find(member)].append(member)
        return group_members

    def __str__(self):
        return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())

n個の要素を0 ~ n - 1の番号で管理する。以下の属性およびメソッドを持つ。

  • parents
    • 各要素の親要素の番号を格納するリスト
    • 要素が根(ルート)の場合は-(そのグループの要素数)を格納する
  • find(x)
    • 要素xが属するグループの根を返す
  • union(x, y)
    • 要素xが属するグループと要素yが属するグループとを併合する
  • size(x)
    • 要素xが属するグループのサイズ(要素数)を返す
  • same(x, y)
    • 要素x, yが同じグループに属するかどうかを返す
  • members(x)
    • 要素xが属するグループに属する要素をリストで返す
    • その後に行う処理によっては集合内包表記やジェネレータ式を使うほうが効率的かもしれない(下のroots()も同じ)
  • roots()
    • すべての根の要素をリストで返す
  • group_count()
    • グループの数を返す
  • all_group_members
  • __str__()

なぜこのようなコードになっているかについては後述。素朴な実装から順を追って説明する。

なお、上述のように、Union Findの主要な操作はUnionFindのみ。上の例の場合、size()以降は便宜的に加えたメソッドなので必要なければ省略しても問題ない。

union()unite()find()root()same()find()といった名前で実装されている例も多い模様。比較する際は適宜読み替えていただきたい。

サンプルコードのクラスの使い方

上で示したサンプルコードのクラスを実際に使ってみる。

要素数6個で初期化。はじめはすべての要素が根(ルート)となり別々のグループとなる。

uf = UnionFind(6)
print(uf.parents)
# [-1, -1, -1, -1, -1, -1]

print(uf)
# 0: [0]
# 1: [1]
# 2: [2]
# 3: [3]
# 4: [4]
# 5: [5]

union()でグループを併合。

uf.union(0, 2)
print(uf.parents)
# [-2, -1, 0, -1, -1, -1]

print(uf)
# 0: [0, 2]
# 1: [1]
# 3: [3]
# 4: [4]
# 5: [5]

さらにグループを併合していく。

uf.union(1, 3)
print(uf.parents)
uf.union(4, 5)
print(uf.parents)
uf.union(1, 4)
print(uf.parents)
# [-2, -2, 0, 1, -1, -1]
# [-2, -2, 0, 1, -2, 4]
# [-2, -4, 0, 1, 1, 4]

print(uf)
# 0: [0, 2]
# 1: [1, 3, 4, 5]

ここで注目すべきはparents[5]の値が4であること。5は根が1のグループに属しているがparentsに格納された親要素は4parentsに格納されているのはあくまでも親要素であり根要素であるとは限らないことに注意。

find()が実行されると経路圧縮(後述)により親要素が根要素に更新される。

print(uf.parents)
# [-2, -4, 0, 1, 1, 1]

例の実装では__str__()内ですべての要素に対してfind()が実行されるので、print(uf)によって__str__()が実行された段階で経路圧縮が行われている。all_group_members()メソッドも同様。

例えば処理の理解のためにこのクラスでparentsの遷移を確認するような場合は、print(uf)all_group_members()を途中で使ってしまうとその時点で経路圧縮が行われるので注意。グループ分けの結果は変わらないのでparentsの中身を気にしない(結果だけを見たい)のであれば問題ない。

最終状態で各メソッドを実行すると以下のようになる。

print(uf.find(0))
# 0

print(uf.find(5))
# 1

print(uf.size(0))
# 2

print(uf.size(5))
# 4

print(uf.same(0, 2))
# True

print(uf.same(0, 5))
# False

print(uf.members(0))
# [0, 2]

print(uf.members(5))
# [1, 3, 4, 5]

print(uf.roots())
# [0, 1]

print(uf.group_count())
# 2

print(uf.all_group_members())
# {0: [0, 2], 1: [1, 3, 4, 5]}

print(list(uf.all_group_members().values()))
# [[0, 2], [1, 3, 4, 5]]

最後に示したように、グループごとの要素のリストのリストを取得したい場合はall_group_members()で取得したdefaultdictvalues()メソッドを適用する。

文字列やタプルなどを要素とする場合

上の例のクラスでは要素の個数nを指定し、n個の要素を0 ~ n - 1の番号で管理する。find()union()などの引数にはその番号を指定する。

例えば人の名前や都市の名前など何らかの文字列を元にUnion Findデータ構造を構成して処理したい場合は、文字列と番号のペアの辞書を用意すると便利。

{文字列: 番号, ... }{番号: 文字列, ...}の辞書をそれぞれ用意する。ここでは辞書内包表記を使う。

l = ['A', 'B', 'C', 'D', 'E']

d = {x: i for i, x in enumerate(l)}
print(d)
# {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}

d_inv = {i: x for i, x in enumerate(l)}
print(d_inv)
# {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E'}

Union Find自体は0 ~ n - 1の番号で管理するが、2つの辞書を適宜使用することでfind()union()などの引数を文字列で指定したり、出力を文字列に変換したりできる。

uf_s = UnionFind(len(l))
print(uf_s)
# 0: [0]
# 1: [1]
# 2: [2]
# 3: [3]
# 4: [4]

uf_s.union(d['A'], d['D'])
uf_s.union(d['D'], d['C'])
uf_s.union(d['E'], d['B'])
print(uf_s)
# 0: [0, 2, 3]
# 4: [1, 4]

print(d_inv[uf_s.find(d['D'])])
# A

print(uf_s.size(d['D']))
# 3

print(uf_s.same(d['A'], d['D']))
# True

print([d_inv[i] for i in uf_s.members(d['D'])])
# ['A', 'C', 'D']

print([d_inv[i] for i in uf_s.roots()])
# ['A', 'E']

print(uf_s.group_count())
# 2

元のクラスを継承した以下のようなクラスを定義してもよい。__init__()で2つの辞書を生成して入出力時に使う。

class UnionFindLabel(UnionFind):
    def __init__(self, labels):
        assert len(labels) == len(set(labels))

        self.n = len(labels)
        self.parents = [-1] * self.n
        self.d = {x: i for i, x in enumerate(labels)}
        self.d_inv = {i: x for i, x in enumerate(labels)}

    def find_label(self, x):
        return self.d_inv[super().find(self.d[x])]

    def union(self, x, y):
        super().union(self.d[x], self.d[y])

    def size(self, x):
        return super().size(self.d[x])

    def same(self, x, y):
        return super().same(self.d[x], self.d[y])

    def members(self, x):
        root = self.find(self.d[x])
        return [self.d_inv[i] for i in range(self.n) if self.find(i) == root]

    def roots(self):
        return [self.d_inv[i] for i, x in enumerate(self.parents) if x < 0]

    def all_group_members(self):
        group_members = defaultdict(list)
        for member in range(self.n):
            group_members[self.d_inv[self.find(member)]].append(self.d_inv[member])
        return group_members

ここで、find()は他のメソッド内で使われており元のメソッドを上書きすると面倒なのでfind_label()という別名のメソッドを定義している。外から根の要素を取得したい場合はfind_label()を使う。

使用例は以下の通り。

l = ['A', 'B', 'C', 'D', 'E']

ufl = UnionFindLabel(l)
print(ufl)
# A: ['A']
# B: ['B']
# C: ['C']
# D: ['D']
# E: ['E']

ufl.union('A', 'D')
ufl.union('D', 'C')
ufl.union('E', 'B')
print(ufl)
# A: ['A', 'C', 'D']
# E: ['B', 'E']

print(ufl.find_label('D'))
# A

print(ufl.size('D'))
# 3

print(ufl.same('A', 'D'))
# True

print(ufl.members('D'))
# ['A', 'C', 'D']

print(ufl.roots())
# ['A', 'E']

print(ufl.group_count())
# 2

print(ufl.all_group_members())
# {'A': ['A', 'C', 'D'], 'E': ['B', 'E']}

文字列に限らず、0始まりの連番以外の数値を要素のラベルとしたい場合などにも利用可能。

ufl_n = UnionFindLabel([1, 2, 3, 4, 5])
print(ufl_n)
# 1: [1]
# 2: [2]
# 3: [3]
# 4: [4]
# 5: [5]

ufl_n.union(1, 4)
ufl_n.union(4, 3)
ufl_n.union(5, 2)
print(ufl_n)
# 1: [1, 3, 4]
# 5: [2, 5]

上の例の場合は元のクラスをn + 1で初期化してもよい。0の要素は使われないのでグループの数は常にひとつ多くなることに注意。

ufl_n2 = UnionFind(6)
print(ufl_n2)
# 0: [0]
# 1: [1]
# 2: [2]
# 3: [3]
# 4: [4]
# 5: [5]

ufl_n2.union(1, 4)
ufl_n2.union(4, 3)
ufl_n2.union(5, 2)
print(ufl_n2)
# 0: [0]
# 1: [1, 3, 4]
# 5: [2, 5]

文字列や数値の他にも、辞書のキーとして使えるオブジェクト(hashableなオブジェクト)であれば何でもいい。タプルでもOK。リストはダメ。

座標を管理したい場合などにはタプルを使うと便利。

ufl_t = UnionFindLabel([(0, 0), (0, 1), (1, 0), (1,1)])
print(ufl_t)
# (0, 0): [(0, 0)]
# (0, 1): [(0, 1)]
# (1, 0): [(1, 0)]
# (1, 1): [(1, 1)]

ufl_t.union((0, 1), (1, 0))
ufl_t.union((0, 0), (1, 0))
print(ufl_t)
# (0, 1): [(0, 0), (0, 1), (1, 0)]
# (1, 1): [(1, 1)]

print(ufl_t.same((0, 1), (0, 0)))
# True

素朴な実装

ここからは素朴な実装からいくつかの工夫を加えて最終型に至るまでを説明する。冒頭に紹介した以下のスライドを適宜参照する。

まず素朴な実装の例は以下の通り。

class UnionFindBasic():
    def __init__(self, n):
        self.parents = list(range(n))

    def find(self, x):
        if self.parents[x] == x:
            return x
        else:
            return self.find(self.parents[x])

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        self.parents[y] = x

parentsには各要素の親要素の番号を格納する。自分自身の番号と親番号が一致しているとその要素は根であると判定できる。

使用例は以下の通り。

初期状態はすべての要素が根(parents[i] == i)。

ufb = UnionFindBasic(5)
print(ufb.parents)
# [0, 1, 2, 3, 4]

union()では単純に第二引数のグループの根の親を第一引数のグループの親に変更しているだけ。

ufb.union(3, 4)
print(ufb.parents)
ufb.union(2, 3)
print(ufb.parents)
ufb.union(1, 2)
print(ufb.parents)
ufb.union(0, 4)
print(ufb.parents)
# [0, 1, 2, 3, 3]
# [0, 1, 2, 2, 3]
# [0, 1, 1, 2, 3]
# [0, 0, 1, 2, 3]

print([ufb.find(i) for i in range(5)])
# [0, 0, 0, 0, 0]

最終的にすべての根が0となり同じグループに属しているが、スライド10ページのような縦長のツリーになってしまっている。

この場合、例えばfind(4)とした場合、3 -> 2 -> 1と辿らないと根を取得できない。この程度のサイズであれば問題ないが、さらに要素数が増えると処理に時間がかかってしまい非効率。

実装の効率化

効率的な実装のためにいくつかの工夫がある。

経路圧縮(Path Compression)

1つ目の工夫は経路圧縮(Path Compression)。スライド11ページ。

find()で根を調べる際に、調べた要素の親を根に変更しつなぎ直す。

経路圧縮を加えた実装は以下の通り。find()で根番号を取得する際に、親番号を根番号に更新する。

class UnionFindPathCompression():
    def __init__(self, n):
        self.parents = list(range(n))

    def find(self, x):
        if self.parents[x] == x:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        self.parents[y] = x

素朴な実装の例と同じようにunion()で併合していくと以下のような結果となる。

ufpc = UnionFindPathCompression(5)
print(ufpc.parents)
# [0, 1, 2, 3, 4]

ufpc.union(3, 4)
print(ufpc.parents)
ufpc.union(2, 3)
print(ufpc.parents)
ufpc.union(1, 2)
print(ufpc.parents)
ufpc.union(0, 4)
print(ufpc.parents)
# [0, 1, 2, 3, 3]
# [0, 1, 2, 2, 3]
# [0, 1, 1, 2, 3]
# [0, 0, 1, 1, 1]

print([ufpc.find(i) for i in range(5)])
# [0, 0, 0, 0, 0]

最後のunion(0, 4)の中で呼ばれるfind(4)によって親番号(parentsの値)が更新され、深い形状が解消される。

当然、find()が呼ばれない限り経路圧縮は行われないが、平均すると計算量が削減できる。

ランク(Union by Rank)

2つ目の工夫はランク(Union by Rank)。スライド11ページ。

ランク(木の高さ)の情報を保持しておき、併合する際に低い方を高い方につなげる(低い方の根の親を高い方の根にする)。

経路圧縮にランクを加えた実装は以下の通り。ランク情報を格納するリストであるrank属性を追加する。[0] * nで初期化する。

union()でランクを元に併合する。ランクが同じグループを併合する場合は親(根が変わらない方)のランクを1増やす。

class UnionFindByRank():
    def __init__(self, n):
        self.parents = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        if self.parents[x] == x:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.rank[x] < self.rank[y]:
            self.parents[x] = y
        else:
            self.parents[y] = x
            if self.rank[x] == self.rank[y]:
                self.rank[x] += 1

これまでの例と同様に併合していくと、以下のようにすべての子要素が根に接続された状態になる。経路圧縮のみのときよりも効率的。

ufbr = UnionFindByRank(5)
print(ufbr.parents)
# [0, 1, 2, 3, 4]

ufbr.union(3, 4)
print(ufbr.parents)
ufbr.union(2, 3)
print(ufbr.parents)
ufbr.union(1, 2)
print(ufbr.parents)
ufbr.union(0, 4)
print(ufbr.parents)
# [0, 1, 2, 3, 3]
# [0, 1, 3, 3, 3]
# [0, 3, 3, 3, 3]
# [3, 3, 3, 3, 3]

なお、これまでの例とは根の要素が変わっている。Union-Findデータ構造では、同じUnion操作を行った場合、どの要素が同じグループに属するかは一意に定まるが、どの要素が根になるかは処理の内容に依存する。

例えば初期状態から併合するときにどちらを根としても良いように、どの要素が根であるかは特に意味はない。

サイズ(Union by Size)

スライドには記載されていないが、ランクではなくグループのサイズ(要素数)を元に併合時のつなげ方を決定する方法がある。併合する際にサイズが小さい方を大きい方につなげる。

実装は以下の通り。考え方はランクの場合と同じ。サイズの大きい方のsizeに小さい方(子になる方)のsizeを加えて更新していく。

class UnionFindBySize():
    def __init__(self, n):
        self.parents = list(range(n))
        self.size = [1] * n

    def find(self, x):
        if self.parents[x] == x:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.size[x] < self.size[y]:
            self.size[y] += self.size[x]
            self.parents[x] = y
        else:
            self.size[x] += self.size[y]
            self.parents[y] = x

#         if self.size[x] < self.size[y]:
#             x, y = y, x

#         self.size[x] += self.size[y]
#         self.parents[y] = x

場合分けはコメントアウトした部分のように書くこともできる。xのサイズが小さい場合はxyをスワップして、常にxの方のサイズが大きいものとして処理する。

結果は以下の通り。ランクの場合と同じ。

ufbs = UnionFindBySize(5)
print(ufbs.parents)
# [0, 1, 2, 3, 4]

ufbs.union(3, 4)
print(ufbs.parents)
ufbs.union(2, 3)
print(ufbs.parents)
ufbs.union(1, 2)
print(ufbs.parents)
ufbs.union(0, 4)
print(ufbs.parents)
# [0, 1, 2, 3, 3]
# [0, 1, 3, 3, 3]
# [0, 3, 3, 3, 3]
# [3, 3, 3, 3, 3]

サイズが小さい方(子になる方)のsizeはケアしていないので、根以外のsizeの値には意味はない。find()で根を取得するとそのグループのサイズが取得できる。

print(ufbs.size)
# [1, 1, 1, 5, 1]

print(ufbs.size[ufbs.find(0)])
# 5

根の値にサイズ、ランクを格納

AtCoderおよびAOJのいくつかの解答コードで使われていたテクニックが、根の親の値にサイズの情報を持たせるというもの。誰が考えたのか知らないが賢い。感動した。

これまでの例では根のparentにはその要素自身の値(番号)を格納していたが、そこに-(そのグループの要素数)(マイナスの要素数)を格納する。

正負で根かそうでないかを判断し、負の場合(根の場合)はその絶対値でサイズ(要素数)を取得できる。サイズではなくランクを格納してもよい。

これにより、リストとして保持していたsizerankが必要なくなる。

サイズの場合の実装例は以下の通り。冒頭で示した最終形と同じ。スワップを使っている。

class UnionFind():
    def __init__(self, n):
        self.parents = [-1] * n

    def find(self, x):
        if self.parents[x] < 0:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.parents[x] > self.parents[y]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x

結果は以下の通り。

uf = UnionFind(5)
print(uf.parents)
# [-1, -1, -1, -1, -1]

uf.union(3, 4)
print(uf.parents)
uf.union(2, 3)
print(uf.parents)
uf.union(1, 2)
print(uf.parents)
uf.union(0, 4)
print(uf.parents)
# [-1, -1, -1, -2, 3]
# [-1, -1, 3, -3, 3]
# [-1, 3, 3, -4, 3]
# [3, 3, 3, -5, 3]

関連カテゴリー

関連記事