Pythonのsorted()やmax()などで引数keyを指定

Modified: | Tags: Python

Pythonの組み込み関数sorted()max(), min()、リストのsort()メソッドなどでは、引数keyに呼び出し可能オブジェクトを指定できる。これにより、要素に何らかの処理を行った結果を元にソートしたり最大値・最小値を取得したりできる。

operatorモジュールについては以下の記事も参照。

引数keyを使った具体例は以下の記事を参照。

組み込み関数を引数keyに指定

引数keyを使う簡単な例として、組み込み関数を指定する。

デフォルトでは、sorted()ではリストなどのイテラブルオブジェクトの要素がそのまま比較されてソートされる。

l = [1, -3, 2]

print(sorted(l))
# [-3, 1, 2]
source: key_usage.py

例えば、絶対値を返す組み込み関数abs()keyに指定すると、絶対値で比較されてソートされる。呼び出し可能オブジェクトを引数に指定する場合、括弧()はいらないので注意。

print(sorted(l, key=abs))
# [1, 2, -3]
source: key_usage.py

keyに指定する関数は比較の際に使われるのみで、結果の要素は元の値のまま。要素に関数を適用して変換したい場合はリスト内包表記を使う。

l_abs = [abs(i) for i in l]
print(l_abs)
# [1, 3, 2]

print(sorted(l_abs))
# [1, 2, 3]
source: key_usage.py

リストのsort()メソッドでも同様に引数keyを指定できる。

l.sort(key=abs)
print(l)
# [1, 2, -3]
source: key_usage.py

sort()は元のリストを更新する破壊的処理なので注意。昇順・降順を指定する引数reverseなど、sorted()sort()の詳細については以下の記事を参照。

最大値・最小値を返す組み込み関数max(), min()でも同様に引数keyを指定できる。

l = [1, -3, 2]

print(max(l))
# 2

print(max(l, key=abs))
# -3

print(min(l))
# -3

print(min(l, key=abs))
# 1
source: key_usage.py

なお、keyはキーワード専用引数なので、必ずkey=xxxという形で指定する必要がある。

以降、sorted()を例とするが、sort()max(), min()でも考え方は同じ。

別の例として、文字列のリストの場合を示す。デフォルトでは文字コード順(アルファベット順)にソートされるが、文字数を返す組み込み関数len()keyに指定することで、文字数の順番にソートできる。

l_str = ['bbb', 'c', 'aa']

print(sorted(l_str))
# ['aa', 'bbb', 'c']

print(sorted(l_str, key=len))
# ['c', 'aa', 'bbb']
source: key_usage.py

ラムダ式(無名関数)や自作の関数を引数keyに指定

組み込み関数だけでなくラムダ式(無名関数)やdefで定義した自作の関数を引数keyに指定することも可能。組み込み関数ではできない複雑な処理を適用できる。

リストを要素とするリスト(リストのリスト、二次元リスト)を例とする。

リスト同士を比較する場合、最初の等しくない要素が比較される(=先頭の要素から順に比較される)。

l_2d = [[2, 10], [1, -30], [-3, 20]]

print(sorted(l_2d))
# [[-3, 20], [1, -30], [2, 10]]
source: key_usage.py

ここで、最大値を返す組み込み関数max()を引数keyに指定すると、各リストの最大値を基準に並べ替えられる。

print(sorted(l_2d, key=max))
# [[1, -30], [2, 10], [-3, 20]]
source: key_usage.py

さらに、各リストの絶対値の最大値を基準にソートしたい場合、ラムダ式を使う。

print(sorted(l_2d, key=lambda x: max([abs(i) for i in x])))
# [[2, 10], [-3, 20], [1, -30]]
source: key_usage.py

なお、リストの要素数が少ない場合は特に気にする必要はないが、ここでmax()の引数としてリスト内包表記のジェネレータ版であるジェネレータ式を使うとメモリ使用量を抑えられる場合がある。

print(sorted(l_2d, key=lambda x: max(abs(i) for i in x)))
# [[2, 10], [-3, 20], [1, -30]]
source: key_usage.py

ラムダ式ではなくdefで関数を定義して引数keyに指定してもよい。

def max_abs(x):
    return max(abs(i) for i in x)

print(sorted(l_2d, key=max_abs))
# [[2, 10], [-3, 20], [1, -30]]
source: key_usage.py

operator.itemgetter()を引数keyに指定

標準ライブラリoperatorのitemgetter()はリストの要素や辞書の値を取得する呼び出し可能オブジェクトを返す。

リストに対するoperator.itemgetter()

以下のように、リストのリストを任意の位置(インデックス)の値に従ってソートできる。

import operator

l_2d = [[2, 10], [1, -30], [-3, 20]]

print(sorted(l_2d, key=operator.itemgetter(1)))
# [[1, -30], [2, 10], [-3, 20]]

同じ処理はラムダ式でも実現可能。

print(sorted(l_2d, key=lambda x: x[1]))
# [[1, -30], [2, 10], [-3, 20]]

ただし、operator.itemgetter()のほうがラムダ式より高速。ラムダ式とoperator.itemgetter()との簡単な処理速度比較の結果は最後に述べる。

辞書に対するoperator.itemgetter()

operator.itemgetter()は辞書dictに対しても使える。

共通のキーを持つ辞書のリストを例とする。辞書同士は比較できないのでデフォルトではエラーになるが、operator.itemgetter()keyに指定すると任意のキーの値を基準にソートできる。

l_dict = [{'k1': 2, 'k2': 10}, {'k1': 1}, {'k1': 3}]

# print(sorted(l_dict))
# TypeError: '<' not supported between instances of 'dict' and 'dict'

print(sorted(l_dict, key=operator.itemgetter('k1')))
# [{'k1': 1}, {'k1': 2, 'k2': 10}, {'k1': 3}]

指定したキーを持たない辞書オブジェクトが含まれているとエラーになるので注意。

# print(sorted(l_dict, key=operator.itemgetter('k2')))
# KeyError: 'k2'

同じ処理はラムダ式でも実現可能。

print(sorted(l_dict, key=lambda x: x['k1']))
# [{'k1': 1}, {'k1': 2, 'k2': 10}, {'k1': 3}]

ラムダ式で辞書のget()メソッドを使うと、指定したキーを持たない場合に任意の値に置き換えてソートできる。以下の記事を参照。

operator.itemgetter()に複数の引数を指定

operator.itemgetter()に複数の引数を指定すると、複数の値を含むタプルが返される。

l_dict = [{'k1': 2, 'k2': 'ccc'}, {'k1': 1, 'k2': 'ccc'}, {'k1': 2, 'k2': 'aaa'}]

print(operator.itemgetter('k1', 'k2')(l_dict[0]))
# (2, 'ccc')

タプルの比較もリストの比較と同様、先頭の要素から順に比較される。したがって、複数の引数を指定したoperator.itemgetter()keyに指定すると、はじめの値が等しい場合は次の値が比較されてソートされる。

print(sorted(l_dict, key=operator.itemgetter('k1')))
# [{'k1': 1, 'k2': 'ccc'}, {'k1': 2, 'k2': 'ccc'}, {'k1': 2, 'k2': 'aaa'}]

print(sorted(l_dict, key=operator.itemgetter('k1', 'k2')))
# [{'k1': 1, 'k2': 'ccc'}, {'k1': 2, 'k2': 'aaa'}, {'k1': 2, 'k2': 'ccc'}]

順番を変えると結果も変わるので注意。

print(sorted(l_dict, key=operator.itemgetter('k2', 'k1')))
# [{'k1': 2, 'k2': 'aaa'}, {'k1': 1, 'k2': 'ccc'}, {'k1': 2, 'k2': 'ccc'}]

これもラムダ式で実現可能。

print(sorted(l_dict, key=lambda x: (x['k1'], x['k2'])))
# [{'k1': 1, 'k2': 'ccc'}, {'k1': 2, 'k2': 'aaa'}, {'k1': 2, 'k2': 'ccc'}]

上述のように処理速度ではoperator.itemgetter()のほうが有利だが、ラムダ式はoperatorをインポートする必要がないのでお手軽。特に処理速度をシビアに求める必要がなければラムダ式が使われることもある。以下のoperator.attrgetter(), operator.methodcaller()でも同じ。

operator.attrgetter()を引数keyに指定

operator.attrgetter()operator.itemgetter()の属性版。任意の属性を取得する呼び出し可能オブジェクトを返す。

日付を表すdatetime.dateオブジェクトのリストを例とする。datetime.dateday, month, year属性でそれぞれ日・月・年を取得できる。

import datetime

l_dt = [datetime.date(2003, 2, 10), datetime.date(2001, 3, 20), datetime.date(2002, 1, 30)]

print(l_dt[0])
# 2003-02-10

print(l_dt[0].day)
# 10

f = operator.attrgetter('day')
print(f(l_dt[0]))
# 10

デフォルトでは日付の順(時系列順)にソートされるが、operator.attrgetter()で任意の属性を基準にソートできる。

print(sorted(l_dt))
# [datetime.date(2001, 3, 20), datetime.date(2002, 1, 30), datetime.date(2003, 2, 10)]

print(sorted(l_dt, key=operator.attrgetter('day')))
# [datetime.date(2003, 2, 10), datetime.date(2001, 3, 20), datetime.date(2002, 1, 30)]

ラムダ式でも実現可能。処理速度はoperator.attrgetter()のほうが高速。

print(sorted(l_dt, key=lambda x: x.day))
# [datetime.date(2003, 2, 10), datetime.date(2001, 3, 20), datetime.date(2002, 1, 30)]

operator.methodcaller()を引数keyに指定

operator.methodcaller()は任意のメソッドを実行する呼び出し可能オブジェクトを返す。

文字列strの特定の文字列の位置を返すfind()メソッドを例とする。

l_str = ['0_xxxxA', '1_Axxxx', '2_xxAxx']

print(l_str[0])
# 0_xxxxA

print(l_str[0].find('A'))
# 6

f = operator.methodcaller('find', 'A')
print(f(l_str[0]))
# 6

デフォルトでは文字コードの順にソートされるが、operator.methodcaller()で任意のメソッドを実行した結果を元にソートできる。

print(sorted(l_str))
# ['0_xxxxA', '1_Axxxx', '2_xxAxx']

print(sorted(l_str, key=operator.methodcaller('find', 'A')))
# ['1_Axxxx', '2_xxAxx', '0_xxxxA']

ラムダ式でも実現可能。処理速度はoperator.methodcaller()のほうが高速。

print(sorted(l_str, key=lambda x: x.find('A')))
# ['1_Axxxx', '2_xxAxx', '0_xxxxA']

ラムダ式とoperator.itemgetter()との処理速度比較

ラムダ式とoperator.itemgetter()との簡易的な処理速度比較の結果を示す。

共通のキーを持つ辞書dictのリスト(要素数10000)を例とする。

import operator

l = [{'k1': i} for i in range(10000)]

print(len(l))
# 10000

print(l[:5])
# [{'k1': 0}, {'k1': 1}, {'k1': 2}, {'k1': 3}, {'k1': 4}]

print(l[-5:])
# [{'k1': 9995}, {'k1': 9996}, {'k1': 9997}, {'k1': 9998}, {'k1': 9999}]

Jupyter Notebookのマジックコマンド%%timeitで処理時間を計測する。

結果は以下の通り。便宜上、Pythonのスクリプトとして掲載するが、そのままPythonで実行しても計測されないので注意。

%%timeit
sorted(l, key=lambda x: x['k1'])
# 1.09 ms ± 35 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
sorted(l, key=operator.itemgetter('k1'))
# 716 µs ± 28.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
sorted(l, key=lambda x: x['k1'], reverse=True)
# 1.11 ms ± 41.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
sorted(l, key=operator.itemgetter('k1'), reverse=True)
# 768 µs ± 58.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
max(l, key=lambda x: x['k1'])
# 1.33 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
max(l, key=operator.itemgetter('k1'))
# 813 µs ± 54.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
min(l, key=lambda x: x['k1'])
# 1.27 ms ± 69.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
min(l, key=operator.itemgetter('k1'))
# 824 µs ± 83.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

上の結果では、sorted(), max(), min()いずれの関数でもラムダ式を100%とするとoperator.itemgetter()の処理時間は70%程度になっており、operator.itemgetter()のほうが高速であることが分かる。

当然ながら、数値は環境や処理の違い(要素数など)によって異なるので、上の結果はあくまでも参考とされたい。

関連カテゴリー

関連記事