Argmax in Python

Computing the argmax is a common programming problem: you have a set of candidates C and a score function f, and you want a candidate c in C that maximizes f(c). If there is only one such candidate, it is called the "argmax" of f on C. One can find solutions for this online, e.g. in this post, but I realized recently that Python's built-in max() function can handle the job directly. Here is how.

So, we have a set of candidates and we want to find the one that maximizes a score function, for instance:

In [1]: candidates = [1, 5, 8, 11, 9]
In [2]: f = lambda x: (x - 4) * (x - 10) * (x - 11)
In [3]: [f(x) for x in candidates]
Out[3]: [-270, 30, 24, 0, 10]

In this case the answer is 5, the candidate at which f is maximal with f(5)=30. A first solution could be to compute the values of f on candidates and use list.index() to return the index of a maximal element:

In [1]: l = [f(x) for x in candidates]
In [2]: fmax = max(l)
In [3]: best_candidate = candidates[l.index(fmax)]
In [4]: best_candidate
Out[4]: 5

This approach relies on the fact that l and candidates share the same indexes, and it is a little verbose. There is a one-liner that circumvents this:

In [5]: max(candidates, key=f)
Out[5]: 5

In detail: the key argument to max() is a one-argument function applied to each element of the iterable before comparison (its default is None, which means elements are compared directly, as if key were the identity). When several elements are maximal, max() returns the first one encountered. For instance, max(i, j, key=f) is equivalent to i if f(i) >= f(j) else j.
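
To make the tie case concrete, here is a small sketch. The function g and the values -2 and 2 are made up for illustration (they are not part of the example above); both values get the same score, so the result depends only on argument order:

```python
# Toy score function (an assumption for this sketch): -2 and 2 both score 4.
g = lambda x: x * x

# With equal scores, max() keeps whichever argument comes first.
first_wins = max(-2, 2, key=g)
swapped = max(2, -2, key=g)

# The explicit conditional expression gives the same answer.
i, j = -2, 2
explicit = i if g(i) >= g(j) else j
```

Here first_wins and explicit are both -2, while swapped is 2. If ties matter in your application, the order of the candidates is what decides.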

In short:

argmax = lambda iterable, func: max(iterable, key=func)
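
As a usage sketch, here is the same helper written with def, applied to a few kinds of iterables. The scores dictionary and the quadratic below are hypothetical data chosen for illustration, and the default argument to max() assumes Python 3.4 or later:

```python
def argmax(iterable, func):
    """Return the first element of iterable that maximizes func."""
    return max(iterable, key=func)

# Hypothetical scores, made up for this example.
scores = {"alice": 3, "bob": 7, "carol": 5}
# Iterating over a dict yields its keys, and scores.get maps a key to its score.
best_name = argmax(scores, scores.get)

# No indexing is needed, so a generator works as well as a list.
best_x = argmax((x for x in range(10)), lambda x: -(x - 6) ** 2)

# max() raises ValueError on an empty iterable unless a default is supplied.
fallback = max([], key=abs, default=None)
```

With this data, best_name is "bob", best_x is 6 and fallback is None. Unlike the list.index() approach, nothing here requires the candidates to be a sequence.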
Pages of this website are under the CC-BY 4.0 license.