How to Use the key Argument in Python (sorted, max, etc.)

In Python, you can specify a function or other callable for the key argument of built-in functions such as sorted(), max(), and min().

Built-in functions for the key argument

A simple example of using the key argument is to specify a built-in function.

By default, sorted() compares and sorts the list elements as they are.

l = [1, -3, 2]

print(sorted(l))
# [-3, 1, 2]
source: key_usage.py

For example, if you specify the abs() function, which returns an absolute value, for the key argument, the elements will be sorted by the absolute value of each element.

Note that you shouldn't include the parentheses () when specifying a function or other callable as an argument.

print(sorted(l, key=abs))
# [1, 2, -3]
source: key_usage.py
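
As a small illustration of the note about parentheses, the sketch below passes abs() instead of abs. Python evaluates abs() first, and that call fails because abs() requires one argument, so sorted() is never reached. The variable names result and error are chosen only for this example.

```python
l = [1, -3, 2]

# Correct: pass the function object itself.
result = sorted(l, key=abs)
print(result)
# [1, 2, -3]

# Incorrect: abs() is called immediately with no arguments,
# which raises TypeError before sorted() even runs.
error = None
try:
    sorted(l, key=abs())
except TypeError as e:
    error = e
    print(type(e).__name__)
# TypeError
```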

The function specified for the key argument is used only for comparison; the actual elements in the result remain unchanged. If you want to modify list elements directly using a function, consider employing list comprehensions.

l_abs = [abs(i) for i in l]
print(l_abs)
# [1, 3, 2]

print(sorted(l_abs))
# [1, 2, 3]
source: key_usage.py
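
To make the "used only for comparison" point concrete, here is a rough sketch of what sorting with key does conceptually: compute a key for each element, sort by the keys, and return the original elements. The helper name sorted_by_key is hypothetical, and this only approximates the behavior; it is not how CPython implements sorted().

```python
def sorted_by_key(iterable, key):
    # Pair each element with its key and its original position.
    # The position breaks ties, so the elements themselves are never compared.
    decorated = [(key(x), i, x) for i, x in enumerate(iterable)]
    decorated.sort()
    # Return the original, unmodified elements.
    return [x for _, _, x in decorated]

l = [1, -3, 2]

print(sorted_by_key(l, abs))
# [1, 2, -3]

print(sorted_by_key(l, abs) == sorted(l, key=abs))
# True
```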

The same applies to the sort() method of lists.

l.sort(key=abs)
print(l)
# [1, 2, -3]
source: key_usage.py

You can also specify the key argument in max() and min().

l = [1, -3, 2]

print(max(l))
# 2

print(max(l, key=abs))
# -3

print(min(l))
# -3

print(min(l, key=abs))
# 1
source: key_usage.py
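
When several elements share the same key value, max() and min() return the first such element they encounter. The following sketch uses a hypothetical list l_tie in which two elements have the largest absolute value and two have the smallest.

```python
l_tie = [-2, 2, 1, -1]

# abs(-2) and abs(2) are both maximal; the first one is returned.
print(max(l_tie, key=abs))
# -2

# abs(1) and abs(-1) are both minimal; the first one is returned.
print(min(l_tie, key=abs))
# 1
```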

Note that key is a keyword-only argument, so it must always be specified as key=xxx.
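
The sketch below shows what happens when the function is passed positionally instead of as key=xxx. It assumes Python 3, where both sorted() and the sort() method reject the extra positional argument with a TypeError. The list errors is used only to record the outcome.

```python
l = [1, -3, 2]
errors = []

# Passing the function positionally fails for sorted() ...
try:
    sorted(l, abs)
except TypeError as e:
    errors.append(type(e).__name__)

# ... and for the sort() method of lists.
try:
    l.sort(abs)
except TypeError as e:
    errors.append(type(e).__name__)

print(errors)
# ['TypeError', 'TypeError']

# Specifying it as a keyword argument works.
print(sorted(l, key=abs))
# [1, 2, -3]
```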

The following examples use sorted(), but the usage of the key argument is the same in sort(), max(), min(), and so on.

Consider a list of strings. By default, such a list is sorted alphabetically. However, if you want to sort it based on the length of the strings, you can specify len() as the key argument.

l_str = ['bbb', 'c', 'aa']

print(sorted(l_str))
# ['aa', 'bbb', 'c']

print(sorted(l_str, key=len))
# ['c', 'aa', 'bbb']
source: key_usage.py
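
Methods accessed through their class, such as str.lower, are also callables and can be passed as key in the same way. As a minimal sketch, the following sorts a hypothetical list l_case without regard to case.

```python
l_case = ['a', 'B', 'c']

# By default, uppercase letters sort before lowercase letters.
print(sorted(l_case))
# ['B', 'a', 'c']

# With key=str.lower, each string is compared by its lowercase form.
print(sorted(l_case, key=str.lower))
# ['a', 'B', 'c']
```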

Lambda expressions or custom functions for the key argument

For the key argument, you can provide not just built-in functions, but also lambda expressions or custom functions defined using def.

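Consider a two-dimensional list (list of lists). Lists are compared element by element from the start: the first elements are compared first, and later elements are considered only when the earlier ones are equal.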

l_2d = [[2, 10], [1, -30], [-3, 20]]

print(sorted(l_2d))
# [[-3, 20], [1, -30], [2, 10]]
source: key_usage.py
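
The comparison rule for lists can be checked directly with the < operator. In this sketch, l_same_first is a hypothetical list in which two inner lists share the same first element, so their second elements decide the order.

```python
# The first elements differ, so they alone decide the result.
print([1, 5] < [2, 0])
# True

# The first elements are equal, so the second elements are compared.
print([1, 2] < [1, 3])
# True

l_same_first = [[1, 5], [0, 9], [1, 2]]
print(sorted(l_same_first))
# [[0, 9], [1, 2], [1, 5]]
```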

By specifying max() for the key argument, lists are sorted based on their maximum values.

print(sorted(l_2d, key=max))
# [[1, -30], [2, 10], [-3, 20]]
source: key_usage.py

If you want to sort by the maximum absolute value of each list, use a lambda expression.

print(sorted(l_2d, key=lambda x: max([abs(i) for i in x])))
# [[2, 10], [-3, 20], [1, -30]]
source: key_usage.py

For lists with a large number of elements, you can reduce memory usage by using a generator expression with max().

print(sorted(l_2d, key=lambda x: max(abs(i) for i in x)))
# [[2, 10], [-3, 20], [1, -30]]
source: key_usage.py

You can define a function with def instead of a lambda expression and specify it for key.

def max_abs(x):
    return max(abs(i) for i in x)

print(sorted(l_2d, key=max_abs))
# [[2, 10], [-3, 20], [1, -30]]
source: key_usage.py

operator.itemgetter() for the key argument

itemgetter() from the operator module in the standard library returns a callable object that fetches a list element or dictionary value.
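
Before using it as a key, it may help to see the returned callable on its own. In the sketch below, f and g are hypothetical names for the objects returned by operator.itemgetter(); calling f(obj) is equivalent to obj[1], and g(obj) is equivalent to obj['k1'].

```python
import operator

# Fetch the element at index 1.
f = operator.itemgetter(1)
print(f([2, 10]))
# 10

# The same callable works on any object that supports indexing.
print(f('abc'))
# b

# Fetch the value for the key 'k1'.
g = operator.itemgetter('k1')
print(g({'k1': 2, 'k2': 10}))
# 2
```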

operator.itemgetter() for lists

You can sort a two-dimensional list (list of lists) according to the value of any position (index) with operator.itemgetter().

import operator

l_2d = [[2, 10], [1, -30], [-3, 20]]

print(sorted(l_2d, key=operator.itemgetter(1)))
# [[1, -30], [2, 10], [-3, 20]]

You can do the same with a lambda expression.

print(sorted(l_2d, key=lambda x: x[1]))
# [[1, -30], [2, 10], [-3, 20]]

operator.itemgetter() is generally faster than an equivalent lambda expression.

The result of a simple comparison of processing speed between operator.itemgetter() and a lambda expression is described at the end.

operator.itemgetter() for dictionaries

operator.itemgetter() can also be applied to dictionaries (dict).

Consider a list of dictionaries with a common key. Dictionaries do not support ordering comparisons such as <, so sorting such a list without a key raises a TypeError. However, you can use operator.itemgetter() to sort the list based on the value of a given key.

l_dict = [{'k1': 2, 'k2': 10}, {'k1': 1}, {'k1': 3}]

# print(sorted(l_dict))
# TypeError: '<' not supported between instances of 'dict' and 'dict'

print(sorted(l_dict, key=operator.itemgetter('k1')))
# [{'k1': 1}, {'k1': 2, 'k2': 10}, {'k1': 3}]

Note that an error is raised if a dictionary without the specified key is included.

# print(sorted(l_dict, key=operator.itemgetter('k2')))
# KeyError: 'k2'

You can do the same with a lambda expression.

print(sorted(l_dict, key=lambda x: x['k1']))
# [{'k1': 1}, {'k1': 2, 'k2': 10}, {'k1': 3}]

If some dictionaries may not have the specified key, you can call the get() method in a lambda expression to provide a default value instead of raising an error.
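
As a minimal sketch of the get() approach, the following sorts by 'k2' even though only one dictionary has that key. The default values 0 and float('inf') are arbitrary choices for this example; which default is appropriate depends on where the dictionaries without the key should end up.

```python
l_dict = [{'k1': 2, 'k2': 10}, {'k1': 1}, {'k1': 3}]

# A missing 'k2' is treated as 0, so those dictionaries come first.
print(sorted(l_dict, key=lambda x: x.get('k2', 0)))
# [{'k1': 1}, {'k1': 3}, {'k1': 2, 'k2': 10}]

# A missing 'k2' is treated as infinity, so those dictionaries come last.
print(sorted(l_dict, key=lambda x: x.get('k2', float('inf'))))
# [{'k1': 2, 'k2': 10}, {'k1': 1}, {'k1': 3}]
```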

Specify multiple arguments to operator.itemgetter()

If multiple arguments are specified to operator.itemgetter(), the resulting callable returns a tuple containing the value for each argument.

l_dict = [{'k1': 2, 'k2': 'ccc'}, {'k1': 1, 'k2': 'ccc'}, {'k1': 2, 'k2': 'aaa'}]

print(operator.itemgetter('k1', 'k2')(l_dict[0]))
# (2, 'ccc')

Tuples, like lists, are compared in order starting from the first element.

print(sorted(l_dict, key=operator.itemgetter('k1')))
# [{'k1': 1, 'k2': 'ccc'}, {'k1': 2, 'k2': 'ccc'}, {'k1': 2, 'k2': 'aaa'}]

print(sorted(l_dict, key=operator.itemgetter('k1', 'k2')))
# [{'k1': 1, 'k2': 'ccc'}, {'k1': 2, 'k2': 'aaa'}, {'k1': 2, 'k2': 'ccc'}]

print(sorted(l_dict, key=operator.itemgetter('k2', 'k1')))
# [{'k1': 2, 'k2': 'aaa'}, {'k1': 1, 'k2': 'ccc'}, {'k1': 2, 'k2': 'ccc'}]

You can also do the same with a lambda expression.

print(sorted(l_dict, key=lambda x: (x['k1'], x['k2'])))
# [{'k1': 1, 'k2': 'ccc'}, {'k1': 2, 'k2': 'aaa'}, {'k1': 2, 'k2': 'ccc'}]
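
The multi-argument form works with list indices as well as dictionary keys. The sketch below uses a hypothetical list l_2d_tie and sorts it by the second element, using the first element to break ties.

```python
import operator

l_2d_tie = [[2, 10], [1, 10], [3, -30]]

# The callable returns a tuple of (element at index 1, element at index 0).
print(operator.itemgetter(1, 0)(l_2d_tie[0]))
# (10, 2)

print(sorted(l_2d_tie, key=operator.itemgetter(1, 0)))
# [[3, -30], [1, 10], [2, 10]]
```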

operator.attrgetter() for the key argument

operator.attrgetter() returns a callable object that fetches an attribute.

Consider a list of datetime.date objects. You can get the day, month, and year with the day, month, and year attributes.

import datetime

l_dt = [datetime.date(2003, 2, 10), datetime.date(2001, 3, 20), datetime.date(2002, 1, 30)]

print(l_dt[0])
# 2003-02-10

print(l_dt[0].day)
# 10

f = operator.attrgetter('day')
print(f(l_dt[0]))
# 10

By default, they are sorted by date, but you can sort by any attribute with operator.attrgetter().

print(sorted(l_dt))
# [datetime.date(2001, 3, 20), datetime.date(2002, 1, 30), datetime.date(2003, 2, 10)]

print(sorted(l_dt, key=operator.attrgetter('day')))
# [datetime.date(2003, 2, 10), datetime.date(2001, 3, 20), datetime.date(2002, 1, 30)]

operator.attrgetter() is typically faster, but you can do the same with a lambda expression.

print(sorted(l_dt, key=lambda x: x.day))
# [datetime.date(2003, 2, 10), datetime.date(2001, 3, 20), datetime.date(2002, 1, 30)]
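
Like operator.itemgetter(), operator.attrgetter() accepts multiple names, in which case the callable returns a tuple of the attribute values. As a small sketch, the following sorts the same dates by month and then by day, ignoring the year; the name by_month_day is chosen only for this example.

```python
import datetime
import operator

l_dt = [datetime.date(2003, 2, 10), datetime.date(2001, 3, 20), datetime.date(2002, 1, 30)]

by_month_day = operator.attrgetter('month', 'day')
print(by_month_day(l_dt[0]))
# (2, 10)

print(sorted(l_dt, key=by_month_day))
# [datetime.date(2002, 1, 30), datetime.date(2003, 2, 10), datetime.date(2001, 3, 20)]
```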

operator.methodcaller() for the key argument

operator.methodcaller() returns a callable object that calls a method.

Consider the find() method of strings, which returns the index of the first occurrence of a given substring, or -1 if the substring is not found.

l_str = ['0_xxxxA', '1_Axxxx', '2_xxAxx']

print(l_str[0])
# 0_xxxxA

print(l_str[0].find('A'))
# 6

f = operator.methodcaller('find', 'A')
print(f(l_str[0]))
# 6

Strings are sorted alphabetically by default, but you can sort them based on the results of any method using operator.methodcaller().

print(sorted(l_str))
# ['0_xxxxA', '1_Axxxx', '2_xxAxx']

print(sorted(l_str, key=operator.methodcaller('find', 'A')))
# ['1_Axxxx', '2_xxAxx', '0_xxxxA']

operator.methodcaller() is typically faster, but you can do the same with a lambda expression.

print(sorted(l_str, key=lambda x: x.find('A')))
# ['1_Axxxx', '2_xxAxx', '0_xxxxA']
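
The same pattern applies to other methods that take arguments. The sketch below, using hypothetical lists l_cnt and l_miss, sorts by the result of count(), and then shows that a string for which find() returns -1 is placed first.

```python
import operator

l_cnt = ['ax', 'xxx', 'b']

# Equivalent to calling s.count('x') on each string s.
f = operator.methodcaller('count', 'x')
print(f(l_cnt[0]))
# 1

print(sorted(l_cnt, key=f))
# ['b', 'ax', 'xxx']

# find() returns -1 for 'xx', which is smaller than any valid index.
l_miss = ['xA', 'xx', 'Ax']
print(sorted(l_miss, key=operator.methodcaller('find', 'A')))
# ['xx', 'Ax', 'xA']
```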

Speed comparison between lambda expressions and operator.itemgetter()

This section shows the results of a simple speed comparison between lambda expressions and operator.itemgetter().

Consider a list of dictionaries with a common key (10000 elements) as an example.

import operator

l = [{'k1': i} for i in range(10000)]

print(len(l))
# 10000

print(l[:5])
# [{'k1': 0}, {'k1': 1}, {'k1': 2}, {'k1': 3}, {'k1': 4}]

print(l[-5:])
# [{'k1': 9995}, {'k1': 9996}, {'k1': 9997}, {'k1': 9998}, {'k1': 9999}]

The following code uses the Jupyter Notebook magic command %%timeit. Note that this will not work if run as a normal Python script.

%%timeit
sorted(l, key=lambda x: x['k1'])
# 1.09 ms ± 35 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
sorted(l, key=operator.itemgetter('k1'))
# 716 µs ± 28.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
sorted(l, key=lambda x: x['k1'], reverse=True)
# 1.11 ms ± 41.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
sorted(l, key=operator.itemgetter('k1'), reverse=True)
# 768 µs ± 58.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
max(l, key=lambda x: x['k1'])
# 1.33 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
max(l, key=operator.itemgetter('k1'))
# 813 µs ± 54.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
min(l, key=lambda x: x['k1'])
# 1.27 ms ± 69.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
min(l, key=operator.itemgetter('k1'))
# 824 µs ± 83.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In these measurements, operator.itemgetter() was faster than a lambda expression for all three functions: sorted(), max(), and min().

Of course, the results may vary depending on the environment and conditions (number of elements, etc.).
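
Outside Jupyter Notebook, a comparable measurement can be sketched with the timeit module from the standard library. The repetition count n below is an arbitrary choice for this example, and the timings it prints will differ from the figures above.

```python
import operator
import timeit

l = [{'k1': i} for i in range(10000)]

n = 100  # number of repetitions; an arbitrary choice for this sketch

# timeit.timeit() returns the total time in seconds for n calls.
t_lambda = timeit.timeit(lambda: sorted(l, key=lambda x: x['k1']), number=n)
t_itemgetter = timeit.timeit(lambda: sorted(l, key=operator.itemgetter('k1')), number=n)

# Convert to microseconds per call for easier comparison.
print(f'lambda:     {t_lambda / n * 1e6:.1f} us per call')
print(f'itemgetter: {t_itemgetter / n * 1e6:.1f} us per call')
```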
