List Comprehensions

List comprehensions are one of the best features of the Python language. They are also occasionally puzzling to new programmers, and often go unused by more experienced ones as well. However, when used properly, list comprehensions can not only make your code easier to understand, but also much faster.

Section 5.1.4 of the Python tutorial looks at list comprehensions briefly, but doesn't do much to outline how to use them; it just gives some syntax.

(Since I wrote this, many changes have been made to Python list comprehensions in 2.3 and 2.4; 2.4 even adds a new kind of comprehension based on generators, the generator expression. As always, trust your profiler more than what you read here.)

History

Part of the reason list comprehensions seem foreign to people is that you don't find analogues of them in popular languages like C or Java; they have their roots in the LISP functions map and filter (which Python also implements). Understanding list comprehensions means understanding map and filter, so we'll start there.

lambda

Important to both map and filter is the concept of a lambda function. A lambda function is (roughly) a function without a name, that does something simple. For example, (lambda x: x * 2) is a lambda function that multiplies its argument by 2. lambda lets you write short functions directly into the code, rather than having to make a def block.
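As a quick illustration, here is the same doubling function written both ways. This is only a sketch, written for Python 3; the names double and double_def are mine and don't come from any library.

```python
# A lambda is an expression, so it can be bound to a name or used inline.
double = lambda x: x * 2

# The equivalent def block needs a statement of its own.
def double_def(x):
    return x * 2

print(double(21))             # 42
print(double_def(21))         # 42
print((lambda x: x * 2)(5))   # 10, called without ever being named
```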

In Python, the body of a lambda function must be a single expression. This means you can't use statements in it, such as = (assignment) or print (which is a statement in Python 2).

map

The map function transforms a list by passing each element of it to a function, often a lambda function. Let's say I have the list l = [1, 2, 3, 4] and I want the square of each element. I can say map((lambda x: x * x), l) and get [1, 4, 9, 16] out. You don't need to use a lambda function; I could write def square(x): return x * x and then call map(square, l). This is even a little faster than using the lambda, but less readable since you have to track down the square function elsewhere in the source.
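Here is that example as something you can paste and run. One assumption to flag: this sketch targets Python 3, where map returns a lazy iterator rather than a list, so I wrap it in list() to get the list the text describes. The variable names squares_lambda and squares_named are just for illustration.

```python
l = [1, 2, 3, 4]

def square(x):
    return x * x

# Python 3 assumption: map is lazy, so list() forces the results.
squares_lambda = list(map(lambda x: x * x, l))
squares_named = list(map(square, l))

print(squares_lambda)  # [1, 4, 9, 16]
print(squares_named)   # [1, 4, 9, 16]
```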

There's also a little trick you can do with map. map(None, l, m, ... z) will return a list of tuples [(l[0], m[0], ... z[0]), (l[1], m[1], ... z[1]), ...]. This is like the zip function, with one important difference: when zip hits the end of the shortest list, it stops. When map hits the end of one list, it puts None objects where the next element of that list would go in the tuple. So you can use map to pair up lists of different sizes, which you can't do with zip without losing the extra elements.
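The map(None, ...) form only works in Python 2; Python 3 dropped it. Assuming you are on Python 3, the closest equivalent I know of is itertools.zip_longest, which pads the shorter inputs with None by default. A small sketch of the difference, with made-up lists:

```python
from itertools import zip_longest

l = [1, 2, 3]
m = ["a", "b"]

# zip stops as soon as the shorter list runs out.
stopped = list(zip(l, m))

# zip_longest keeps going and fills the gaps with None,
# which is what map(None, l, m) did in Python 2.
padded = list(zip_longest(l, m))

print(stopped)  # [(1, 'a'), (2, 'b')]
print(padded)   # [(1, 'a'), (2, 'b'), (3, None)]
```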

filter

The filter function does something similar, except instead of transforming the list, it removes elements from it. filter(function, lst) will return only those elements of lst (in the same order) that function returns true for. filter((lambda s: s[0] == "b"), ["foo", "bar", "baz"]) returns ["bar", "baz"], the strings beginning with "b".
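The same example, runnable. As with map, I'm assuming Python 3 here, where filter is lazy and needs list() around it; the names words and starts_with_b are only for illustration.

```python
words = ["foo", "bar", "baz"]

# Keep only the strings whose first character is "b".
# (s[0] would fail on an empty string; this list has none.)
starts_with_b = list(filter(lambda s: s[0] == "b", words))

print(starts_with_b)  # ['bar', 'baz']
```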

Map + Filter

That's what a list comprehension is: a map and a filter in one. While it's perfectly reasonable to write map(f1, filter(f2, my_list)), it's slow, because you go through the list twice. So instead, we introduce the list comprehension: [f1(x) for x in my_list if f2(x)]. You don't need to type the lambda stuff (or a named function) either. Let's say I want the squares of all numbers above 20 in my list: [x * x for x in my_list if x > 20]. That would be a mouthful with lambda, and even more with named functions.

You can leave out the if clause, if you want. [x * x for x in my_list] squares all numbers in the list. If you just want to filter, [x for x in my_list if x > 20] works.
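To make the equivalence concrete, here is a sketch that builds the same result both ways, plus the map-only and filter-only forms. The sample data is made up, and the list() wrappers are there because I'm assuming Python 3.

```python
my_list = [5, 21, 30, 8, 25]

# map + filter, spelled out with lambdas.
via_map_filter = list(map(lambda x: x * x,
                          filter(lambda x: x > 20, my_list)))

# The same thing as a list comprehension.
via_comprehension = [x * x for x in my_list if x > 20]

# Dropping the if clause gives a plain map; keeping x unchanged gives a plain filter.
only_map = [x * x for x in my_list]
only_filter = [x for x in my_list if x > 20]

print(via_map_filter)     # [441, 900, 625]
print(via_comprehension)  # [441, 900, 625]
print(only_map)           # [25, 441, 900, 64, 625]
print(only_filter)        # [21, 30, 25]
```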

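There's a gotcha in the syntax for list comprehensions. If the result expression is a tuple, you need to put it in parentheses: [(x, x * x) for x in my_list] works, but [x, x * x for x in my_list] is a syntax error. A list result likewise needs its own brackets, as in [[x, x * x] for x in my_list].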
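A short sketch of the gotcha in action. I use the built-in compile function to show that the unparenthesized form is rejected before it ever runs; the variable names are mine.

```python
my_list = [1, 2, 3]

# A tuple result needs parentheses, a list result needs brackets.
pairs = [(x, x * x) for x in my_list]
rows = [[x, x * x] for x in my_list]

# Without the parentheses the comprehension does not even compile.
try:
    compile("[x, x * x for x in my_list]", "<example>", "eval")
    unparenthesized_ok = True
except SyntaxError:
    unparenthesized_ok = False

print(pairs)               # [(1, 1), (2, 4), (3, 9)]
print(rows)                # [[1, 1], [2, 4], [3, 9]]
print(unparenthesized_ok)  # False
```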

Optimization

Function calls in Python are slow. In fact, they're one of the slowest things in Python. This goes for both named functions and lambdas. So here's where list comprehensions win big: they don't use a function! The expression in a list comprehension is evaluated inline for each element rather than through a function call, so it's much faster.
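If you want to measure this yourself rather than take my word for it, the standard timeit module is the easy way. The sketch below is a rough harness, not a rigorous benchmark: the list size and repeat count are arbitrary, the function names are mine, and I've left the timings out because they depend on your interpreter version and machine.

```python
import timeit

my_list = list(range(1000))

def with_map():
    # One lambda call per element.
    return list(map(lambda x: x * 2, my_list))

def with_comprehension():
    # The expression x * 2 is evaluated inline, with no per-element call.
    return [x * 2 for x in my_list]

# Run each version 200 times and report the total elapsed seconds.
map_time = timeit.timeit(with_map, number=200)
comp_time = timeit.timeit(with_comprehension, number=200)

print("map + lambda:  %.4fs" % map_time)
print("comprehension: %.4fs" % comp_time)
```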

When map and filter Win

List comprehensions aren't always faster. Sometimes there's no way around the function call. If I want to convert a list of numbers to strings, I can do [str(i) for i in my_numbers]. But it's a lot faster, and easier to read, if I use map(str, my_numbers). List comprehensions aren't free; they do have more overhead than map. Since you have to use a function call anyway, you might as well just use map.

This doesn't always work. Let's say I want to filter out all objects in a sequence that aren't strings. Python has an isinstance function we need to use, so we're stuck calling a function. Might as well use filter, right? Well, the code is filter((lambda x: isinstance(x, str)), my_things). We can't tell filter to send in two arguments, so we're stuck writing a lambda anyway! This means two function calls, one for the lambda, and one as the lambda calls isinstance. The list comprehension equivalent, [x for x in my_things if isinstance(x, str)], can be faster in this case.

The former benefit isn't exclusive to map, and the latter problem isn't exclusive to filter. When considering map/filter vs. a list comprehension, always look at the functions you're mapping or filtering with. Depending on the function, either one could win. If you're not sure, try each one and profile your code. Do that anyway. You can't profile your code too much.
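Here is a sketch of "try each one" for the two cases in this section. It checks nothing about which version wins; it only times all four so you can compare them on your own setup. The test data, the candidates dictionary, and the repeat count are my own choices, and the list() wrappers assume Python 3. For real code, a proper profiler will tell you more than a micro-benchmark like this.

```python
import timeit

my_numbers = list(range(1000))
my_things = ["a", 1, "b", 2.5, None, "c"] * 200

# Each candidate is a zero-argument callable so timeit can run it directly.
candidates = {
    "map(str)": lambda: list(map(str, my_numbers)),
    "comprehension str": lambda: [str(i) for i in my_numbers],
    "filter + lambda": lambda: list(
        filter(lambda x: isinstance(x, str), my_things)),
    "comprehension isinstance": lambda: [
        x for x in my_things if isinstance(x, str)],
}

timings = {}
for name, func in candidates.items():
    # Total seconds for 200 runs of each candidate.
    timings[name] = timeit.timeit(func, number=200)
    print("%-26s %.4fs" % (name, timings[name]))
```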

Example: DSU

Here's a real example of list comprehensions and their wins over (in this case) map. There's an operation called a Schwartzian transform in the Perl community, and Decorate, Sort, Undecorate among people who don't want to admit that Perl programmers ever designed anything useful. Let's assume you want to sort something awkward; say, a list of lists, sorted by the sum of all the numbers in each. The obvious way is my_list.sort((lambda a, b: cmp(sum(a), sum(b)))) (sum was introduced in Python 2.3).

Students of sorting will know that this takes time approximately proportional to n * log(n) (n is the length of the list), which is the best you can do with sorting (students of Python will know that it's actually proportional to n^2 time, in many cases, but that's unimportant). Each comparison calls sum twice, so sum is going to be called far more than n times, even though there are only n elements, and they won't change while the list is being sorted. DSU is a solution to this problem.

The idea of DSU is to precalculate the sum of each list (decorate), sort by that, and then get the actual lists back out (undecorate). Since sum only gets called n times now, DSU is faster than the naive sort. In Python, we can do decoration quickly by making a tuple of (decorated_value, real_value).

map

decorated = map((lambda l: (sum(l), l)), my_list)
decorated.sort()
my_list = map((lambda l: l[1]), decorated)

List Comprehensions

decorated = [(sum(l), l) for l in my_list]
decorated.sort()
my_list = [l[1] for l in decorated]

There you go. Both are very short, and both are in fact faster than the naive sort. But note how the map method makes twice the length of the list in lambda calls on top of the n calls to sum, while the list comprehension makes no extra function calls at all.
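To check the call counts claimed above, here is a sketch that counts how often sum gets called in the naive sort and in the DSU version. It is written for Python 3, which no longer accepts a comparison function in sort, so I use functools.cmp_to_key as a stand-in for the Python 2 cmp argument. The helpers counted_sum and compare_by_sum and the sample data are hypothetical names of mine. The exact count for the naive sort depends on the input, so I don't state it.

```python
from functools import cmp_to_key

call_count = 0

def counted_sum(seq):
    """Same result as sum, but records how often it is called."""
    global call_count
    call_count += 1
    return sum(seq)

def compare_by_sum(a, b):
    # Stand-in for Python 2's cmp(sum(a), sum(b)): negative, zero or positive.
    sa, sb = counted_sum(a), counted_sum(b)
    return (sa > sb) - (sa < sb)

# 50 small lists whose sums are all different.
my_list = [[i, 100 - i * 3, i % 7] for i in range(50)]

# Naive sort: sum is called twice per comparison.
call_count = 0
naive = sorted(my_list, key=cmp_to_key(compare_by_sum))
naive_calls = call_count

# DSU: sum is called once per element. If two sums tie, the tuples
# fall back to comparing the lists themselves.
call_count = 0
decorated = [(counted_sum(l), l) for l in my_list]
decorated.sort()
dsu = [l[1] for l in decorated]
dsu_calls = call_count

print("naive sort called sum %d times" % naive_calls)
print("DSU called sum %d times" % dsu_calls)
```

Python 2.4 and later can also do the decoration for you: my_list.sort(key=sum) computes the key once per element, which is the same idea built into the sort.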

Incidentally, if you just want to sum each list in a list, and don't care about sorting them, map wins: map(sum, my_list).

Feedback

Email piman (at) sacredchao.net with any suggestions. Don't email me asking for me to optimize your code for you, or asking me whether or not a list comprehension would be faster than map for your specific problem; that's what the Python profiler is for.