Decorators without the Magic


Decorators look like magic when you first see them, and they don't feel like something mere mortals could write. We'll go through some examples of exactly how they work (without the scary ‘@’ syntax), and see why they are a nice pattern for writing elegant, easy-to-understand code.

The following is adapted from a talk I gave at STL Python on January 6, 2015.

We need some code to work with, so let's go with something familiar. The recursive definition of the Fibonacci numbers should be familiar to most coders.

In [1]:
def fib(n):
    """Give the nth fibonacci number.
    
    This code runs in O(2^n) time, for illustration purposes"""
    if 0 <= n <= 1:
        return n
    return fib(n - 1) + fib(n - 2)

Let's try out a few test cases to see how it performs.

In [2]:
for ix in xrange(10):
    print ix, fib(ix)
0 0
1 1
2 1
3 2
4 3
5 5
6 8
7 13
8 21
9 34

In [3]:
%time fib(20)
CPU times: user 5.14 ms, sys: 1.67 ms, total: 6.82 ms
Wall time: 5.46 ms

Out[3]:
6765
In [4]:
%time fib(30)
CPU times: user 497 ms, sys: 19.2 ms, total: 516 ms
Wall time: 503 ms

Out[4]:
832040
In [5]:
%time fib(35)
CPU times: user 5.04 s, sys: 14.6 ms, total: 5.05 s
Wall time: 5.1 s

Out[5]:
9227465

Five seconds, and we're only at n=35! This is not going to cut it. The reason, as you may already know, is that we're repeating a lot of work. For example, fib(5) calls fib(4) and fib(3), but because we don't save fib(3), we have to compute it again when fib(4) needs it. This gives us a runtime of O(2^n): very slow.
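To see the repeated work directly, here is a small sketch that instruments the naive recursion with a counter. The name counted_fib is illustrative; it is the same algorithm as fib with a tally added, and it uses the print() form so it runs under Python 2 or 3.

```python
from collections import Counter

calls = Counter()  # maps each argument n to how many times it was requested

def counted_fib(n):
    """The naive recursive Fibonacci, instrumented to record every call."""
    calls[n] += 1
    if 0 <= n <= 1:
        return n
    return counted_fib(n - 1) + counted_fib(n - 2)

calls.clear()
counted_fib(5)
print(calls[3])             # 2: fib(3) is computed twice for n=5
print(sum(calls.values()))  # 15 calls in total for n=5

calls.clear()
counted_fib(20)
print(sum(calls.values()))  # 21891 calls for n=20
```

For this recursion the total number of calls works out to 2*fib(n+1) - 1, so the cost grows as fast as the Fibonacci numbers themselves.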

Rather than re-compute all that work, let's save the results. fib(3) is going to be 2 no matter who asks, so we can save each result the first time we compute it. This is called 'memoization'.

In [6]:
fib_table = {} #somewhere to store the answers we've already computed

def memo_fib(n):
    """Give the nth fibonacci number.
    
    This code is much faster than before."""
    if n in fib_table:
        return fib_table[n]
    if 0 <= n <= 1:
        return n
    result = memo_fib(n - 1) + memo_fib(n - 2)
    fib_table[n] = result #make sure to save this for the next call
    return result

Let's see how the runtime compares with our unmemoized version.

In [7]:
%time memo_fib(20)
CPU times: user 25 µs, sys: 20 µs, total: 45 µs
Wall time: 38.1 µs

Out[7]:
6765
In [8]:
%time memo_fib(30)
CPU times: user 17 µs, sys: 9 µs, total: 26 µs
Wall time: 22.9 µs

Out[8]:
832040
In [9]:
%time memo_fib(35)
CPU times: user 11 µs, sys: 5 µs, total: 16 µs
Wall time: 16.9 µs

Out[9]:
9227465
In [10]:
%time memo_fib(350)
CPU times: user 424 µs, sys: 297 µs, total: 721 µs
Wall time: 544 µs

Out[10]:
6254449428820551641549772190170184190608177514674331726439961915653414425L
In [11]:
%time memo_fib(1000)
CPU times: user 1.12 ms, sys: 550 µs, total: 1.67 ms
Wall time: 1.36 ms

Out[11]:
43466557686937456435688527675040625802564660517371780402481729089536555417949051890403879840079255169295922593080322634775209689623239873322471161642996440906533187938298969649928516003704476137795166849228875L

Whoa, what a change! Because each fib(k) is now computed only once and then looked up, memoization brings the runtime down to O(n), much better than O(2^n).
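The same kind of tally, applied to a memo_fib-style function with a fresh table, shows where the O(n) comes from. This is a sketch with illustrative names (counted_memo_fib, calls_for): each value is computed once, and every later request is a single table lookup.

```python
table = {}        # saved answers, like fib_table above
call_total = [0]  # a mutable cell so the function can update the count

def counted_memo_fib(n):
    """memo_fib with a call counter added."""
    call_total[0] += 1
    if n in table:
        return table[n]
    if 0 <= n <= 1:
        return n
    result = counted_memo_fib(n - 1) + counted_memo_fib(n - 2)
    table[n] = result
    return result

def calls_for(n):
    """Start from an empty table and report how many calls fib(n) needs."""
    table.clear()
    call_total[0] = 0
    counted_memo_fib(n)
    return call_total[0]

print(calls_for(20))   # 39
print(calls_for(100))  # 199, i.e. 2n - 1 calls: linear, not exponential
```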

We might want to use memoization on other occasions. Decorators can help us with this.

What is the memoization pattern? Keep a dictionary of the results of each call. When a new call comes in, first check whether its answer is already in the table of saved results. If we do have to do the work, save the answer for next time.

Here's the punchline: we can factor out this idea, and make a memoized higher-order function.

In [12]:
def memoized(func):
    """Given a function of one argument,
    return a memoized version of that function.
    """
    lookup_table = {}
    def memoized_version(arg):
        if arg in lookup_table:
            return lookup_table[arg]

        result = func(arg)
        lookup_table[arg] = result

        return result
    return memoized_version

Okay, let's test this out on a function.

In [13]:
import time, sys
def greet(name):
    """Print out a personalized greeting.
    
    (personalization takes time).
    """
    print 'hold on, let me think...'
    time.sleep(2)
    return 'Hi there, {}'.format(name)
In [14]:
%time greet('Brian')
hold on, let me think...
CPU times: user 329 µs, sys: 518 µs, total: 847 µs
Wall time: 2.01 s

Out[14]:
'Hi there, Brian'
In [15]:
%time greet('Dan')
hold on, let me think...
CPU times: user 409 µs, sys: 688 µs, total: 1.1 ms
Wall time: 2.01 s

Out[15]:
'Hi there, Dan'
In [16]:
memoized_greet = memoized(greet)
In [17]:
%time memoized_greet('Dan') # The first call is still slow
hold on, let me think...
CPU times: user 361 µs, sys: 536 µs, total: 897 µs
Wall time: 2 s

Out[17]:
'Hi there, Dan'
In [18]:
%time memoized_greet('Dan') # Subsequent calls are quick!
CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 7.15 µs

Out[18]:
'Hi there, Dan'

So memoization is great, but I don't really want to have two versions of my function around. Let's overwrite the original and just always use memoization.

In [19]:
greet = memoized(greet)
In [20]:
%time greet('Dan')
hold on, let me think...
CPU times: user 347 µs, sys: 534 µs, total: 881 µs
Wall time: 2.01 s

Out[20]:
'Hi there, Dan'
In [21]:
%time greet('Dan')
CPU times: user 4 µs, sys: 2 µs, total: 6 µs
Wall time: 7.87 µs

Out[21]:
'Hi there, Dan'

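Remember our slow O(2^n) implementation of Fibonacci from above? We can memoize it without changing its body. This works because the recursive calls inside fib look up the global name fib each time they run, so once that name refers to the memoized version, the recursion goes through the cache too.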

In [22]:
fib = memoized(fib)
In [23]:
fib(50)
Out[23]:
12586269025
In [24]:
fib(200)
Out[24]:
280571172992510140037611932413038677189525L
In [25]:
fib(500)
Out[25]:
139423224561697880139724382870407283950070256587697307264108962948325571622863290691557658876222521294125L
In [26]:
fib(800)
Out[26]:
69283081864224717136290077681328518273399124385204820718966040597691435587278383112277161967532530675374170857404743017623467220361778016172106855838975759985190398725L
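The rebinding is what makes this work. A sketch to check that claim: wrapping fib under a different name only caches the outermost call, because the recursion inside fib still finds the unmemoized function under the global name fib. The counter raw_calls and the name other_name are illustrative additions.

```python
def memoized(func):
    """Same decorator as above: cache results of a one-argument function."""
    lookup_table = {}
    def memoized_version(arg):
        if arg in lookup_table:
            return lookup_table[arg]
        result = func(arg)
        lookup_table[arg] = result
        return result
    return memoized_version

raw_calls = [0]  # how many times the undecorated body actually runs

def fib(n):
    raw_calls[0] += 1
    if 0 <= n <= 1:
        return n
    return fib(n - 1) + fib(n - 2)  # looks up the global name `fib` at call time

# Wrap without rebinding: the recursion still calls the naive fib.
other_name = memoized(fib)
other_name(20)
unrebound_calls = raw_calls[0]
print(unrebound_calls)  # 21891

# Rebind the global name: now every recursive call goes through the cache.
raw_calls[0] = 0
fib = memoized(fib)
fib(20)
print(raw_calls[0])     # 21, one real computation per value 0..20
```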

And actually, that's decorators. That's exactly what decorators do: they replace a function with a modified version of the same function, and the @ line is shorthand for that reassignment. So we could have written greet this way:

In [27]:
@memoized
def greet(name):
    """Print out a personalized greeting.
    
    (personalization takes time)."""
    print 'hold on, let me think...'
    time.sleep(2)
    return 'Hi there, {}'.format(name)

# Exactly the same as...
def greet(name):
    """Print out a personalized greeting.
    
    (personalization takes time)."""
    print 'hold on, let me think...'
    time.sleep(2)
    return 'Hi there, {}'.format(name)
greet = memoized(greet)
In [28]:
@memoized
def fib(n):
    """Give the nth fibonacci number.
    
    Because of memoization, this code runs in O(n) time
    (but it still might blow your stack)"""
    if 0 <= n <= 1:
        return n
    return fib(n - 1) + fib(n - 2)

# Exactly the same as...
def fib(n):
    """Give the nth fibonacci number.
    
    Because of memoization, this code runs in O(n) time
    (but it still might blow your stack)"""
    if 0 <= n <= 1:
        return n
    return fib(n - 1) + fib(n - 2)
fib = memoized(fib)
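The desugared form also exposes a side effect: after greet = memoized(greet), the name greet points at memoized_version, so its __name__ and docstring no longer describe the original function. The standard library's functools.wraps copies them onto the wrapper. A sketch comparing the two (square is just a hypothetical example function):

```python
import functools

def plain_memoized(func):
    """The memoized logic from above, without functools.wraps."""
    lookup_table = {}
    def memoized_version(arg):
        if arg not in lookup_table:
            lookup_table[arg] = func(arg)
        return lookup_table[arg]
    return memoized_version

def wrapped_memoized(func):
    """Same caching logic, but copies func's __name__ and __doc__ onto the wrapper."""
    lookup_table = {}
    @functools.wraps(func)
    def memoized_version(arg):
        if arg not in lookup_table:
            lookup_table[arg] = func(arg)
        return lookup_table[arg]
    return memoized_version

def square(x):
    """Return x squared."""
    return x * x

plain_square = plain_memoized(square)
wrapped_square = wrapped_memoized(square)

print(plain_square.__name__)    # memoized_version
print(plain_square.__doc__)     # None
print(wrapped_square.__name__)  # square
print(wrapped_square.__doc__)   # Return x squared.
```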

Okay, that's cool, but let's see another example. Say we wanted to time the execution of a function. Decorators allow us to write code that runs before and after the regular function, so this shouldn't be too bad. We'll need some code to time, so let's use a prime checker.

In [29]:
#@timed  # This doesn't exist yet...
def is_prime(n):
    potential_divisor = 2
    while potential_divisor ** 2 <= n:
        if n % potential_divisor == 0:
            return False
        potential_divisor += 1
    return True
In [30]:
is_prime(618)
Out[30]:
False
In [31]:
from datetime import datetime
def timed(func):
    def timed_version(n):
        start = datetime.now()
        result = func(n)
        print 'that took', datetime.now() - start
        return result
    return timed_version

Now that we've written a definition for timed, we can use it in our definition of is_prime.

In [32]:
@timed
def is_prime(n):
    potential_divisor = 2
    while potential_divisor ** 2 <= n:
        if n % potential_divisor == 0:
            return False
        potential_divisor += 1
    return True
In [33]:
is_prime(2**31 - 1)
that took 0:00:00.013842

Out[33]:
True

So that's great, but we swept a small detail under the rug. This will only work for functions of a single parameter (called n)! When we wrote our timed_version, we assumed it would accept n as its only parameter.

This seems like a violation of reuse: our timing decorator should be general. It shouldn't be limited to functions with a single parameter, let alone a single parameter called n.

Suppose we wanted to change the parameters to is_prime, maybe also passing in give_factors, like this:

In [34]:
@timed
def is_prime(num, give_factors):
    """Determine whether a number is prime.
      num - the possibly prime number
      give_factors - True/False, whether to include the factors
    """

    potential_divisor = 2
    factors = []
    while potential_divisor ** 2 <= num:
        if num % potential_divisor == 0:
            factors.append((potential_divisor, num // potential_divisor))
        potential_divisor += 1

    num_is_prime = len(factors) == 0
    if give_factors:
        return num_is_prime, factors
    else:
        return num_is_prime
In [36]:
is_prime(2**31 - 1, True) # This will crash because
# timed_version is expecting only a single parameter
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-36-a7edfb41034e> in <module>()
----> 1 is_prime(2**31 - 1, True) # This will crash because
      2 # timed_version is expecting only a single parameter

TypeError: timed_version() takes exactly 1 argument (2 given)

We'll rewrite this in a moment to be more general, but first we have to talk about another confusing aspect of Python: star-args.

*args, **kwargs

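*args and **kwargs use Python's packing/unpacking operators. A single * turns a list (or any iterable) into a series of positional arguments at a call site, and in a function definition it collects extra positional arguments into a tuple. A double ** does the same, but for dictionaries and keyword arguments.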

Let's do an example.

In [37]:
def add(a, b):
    return a + b
In [38]:
operands = [3, 4]
add(*operands) # --> add(operands[0], operands[1])
Out[38]:
7

The *operands gets turned into a series of positional parameters, one for each item in operands.

What if we use the * in the function definition?

In [39]:
def add(*args):
    total = args[0]
    for addend in args[1:]:
        total += addend
    return total

This new version of add accepts any number of parameters. They are all packed into the args tuple.

In [40]:
add(1, 2, 3, 4, 5)
Out[40]:
15

Suppose we wanted to use both packing and unpacking? We can! *operands here will be unpacked to make a positional argument for each element of the list operands. Then, in the body of the add function, the positional arguments have all been packed into a single variable, args.

In [41]:
operands = [1, 2, 3, 4, 5]
add(*operands) # --> add(operands[0], operands[1], ..., operands[4])
Out[41]:
15

So that's packing and unpacking for positional parameters. What about keyword parameters? We'll use the following function for illustration.

In [42]:
def birthday(name, age):
    return 'Happy Birthday, {}! congrats on number {}.'.format(name, age)
In [43]:
params = {'name': 'Batman', 'age': 31}
print birthday(**params) # --> birthday(name='Batman', age=31)
Happy Birthday, Batman! congrats on number 31.

The params dictionary gets unpacked into keyword arguments and passed into birthday.

This can be useful if we have a whole list of folks we want to wish a Happy Birthday:

In [44]:
birthday_list = [
    {'name': 'Robert', 'age': 12},
    {'name': 'Steven', 'age': 26},
    {'name': 'Cassie', 'age': 39}
]

for bday in birthday_list:
    print birthday(**bday) # --> birthday(name='Robert', age=12)
Happy Birthday, Robert! congrats on number 12.
Happy Birthday, Steven! congrats on number 26.
Happy Birthday, Cassie! congrats on number 39.

Just like with positional parameters, there is a packing operator for keywords as well. We'll illustrate it with a reimplementation of the new-style string format method.

In [45]:
def strformat(template, **kwargs):
    for key, val in kwargs.items():
        template = template.replace('{'+key+'}', val)
    return template

Any keyword parameters that are passed into the strformat function get packaged up into the kwargs dictionary. So strformat doesn't even need to know the names of the parameters you're going to use!

In [46]:
strformat('Happy {occasion}, {name}',
          occasion='Thanksgiving', name='Diane')
Out[46]:
'Happy Thanksgiving, Diane'

Okay, excellent. So *args and **kwargs let us work with functions that are really general.
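A compact way to see both operators at once is a function that simply returns whatever got packed. show_packing is a hypothetical helper for illustration:

```python
def show_packing(*args, **kwargs):
    """Return exactly what the packing operators collected."""
    return args, kwargs

positional, keywords = show_packing(1, 2, name='Diane', occasion='Thanksgiving')
print(positional)                # (1, 2): *args is a tuple
print(sorted(keywords.items()))  # [('name', 'Diane'), ('occasion', 'Thanksgiving')]

# Unpacking at the call site, then packing again inside the function
numbers = [1, 2]
options = {'name': 'Diane'}
print(show_packing(*numbers, **options))  # ((1, 2), {'name': 'Diane'})
```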

This helps us with writing decorators because we want to be able to accept any set of parameters, without knowing or specifying them in advance. Let's revisit our timed decorator as an example.

In [47]:
def timed(func):
    def timed_version(*args, **kwargs):
        start = datetime.now()
        result = func(*args, **kwargs)
        print 'that took', datetime.now() - start
        return result
    return timed_version

Excellent. None of this "you must use a function with one parameter and it must be called n" business.
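As a quick check, here is the same generalized timed applied to functions with no parameters, with a default, and called with keywords. The functions answer and greeting are hypothetical examples:

```python
from datetime import datetime

def timed(func):
    """Print how long each call takes, for any combination of arguments."""
    def timed_version(*args, **kwargs):
        start = datetime.now()
        result = func(*args, **kwargs)
        print('that took {}'.format(datetime.now() - start))
        return result
    return timed_version

@timed
def answer():
    return 42

@timed
def greeting(name, punctuation='!'):
    return 'Hi there, ' + name + punctuation

answer()                               # no parameters at all
greeting('Dan')                        # positional, default used
greeting(name='Dan', punctuation='?')  # keywords only
```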

In [48]:
@timed
def is_prime(num, give_factors):
    """Determine whether a number is prime.
      num - the possibly prime number
      give_factors - True/False, whether to include the factors
    """

    potential_divisor = 2
    factors = []
    while potential_divisor ** 2 <= num:
        if num % potential_divisor == 0:
            factors.append((potential_divisor, num // potential_divisor))
        potential_divisor += 1

    num_is_prime = len(factors) == 0
    if give_factors:
        return num_is_prime, factors
    else:
        return num_is_prime
In [49]:
is_prime(2**31 - 1, True) # This time it actually works!
that took 0:00:00.014745

Out[49]:
(True, [])
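The same star-args trick generalizes memoized. A sketch, assuming every argument is hashable (the arguments become part of a dictionary key); power is a hypothetical example function:

```python
def memoized(func):
    """Memoize a function of any hashable positional and keyword arguments."""
    lookup_table = {}
    def memoized_version(*args, **kwargs):
        # Keyword order shouldn't matter, so keywords go into a frozenset
        key = (args, frozenset(kwargs.items()))
        if key not in lookup_table:
            lookup_table[key] = func(*args, **kwargs)
        return lookup_table[key]
    return memoized_version

computations = []  # records each time the real function body runs

@memoized
def power(base, exponent=2):
    computations.append((base, exponent))
    return base ** exponent

power(3)
power(3)               # served from the table
power(2, exponent=10)
power(2, exponent=10)  # served from the table
print(computations)    # [(3, 2), (2, 10)]
```

One consequence of keying on the arguments as passed: power(3) and power(3, 2) produce different keys, so they are cached separately even though they mean the same thing. Python 3's functools.lru_cache is a standard-library take on the same idea.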

More examples

We'll make a cached decorator that remembers the result it computed, but only for a short time. We could do our fancy *args, **kwargs business and make a kind of cached_and_memoized decorator, but let's keep it simple.

In [50]:
from datetime import datetime, timedelta

def cached(func):
    """Given a function of no arguments, return a version
    that caches its value for ten seconds."""
    cache = {}
    def cached_version():
        if 'expiry' in cache and cache['expiry'] > datetime.now():
            return cache['value']
        else:
            result = func()
            cache['value'] = result
            cache['expiry'] = datetime.now() + timedelta(seconds=10)
            return result
    return cached_version

Now to use the decorator:

In [51]:
from IPython.display import Image
import requests

@cached
def cat_picture():
    print 'hitting api'
    resp = requests.get('http://edgecats.net/random')
    return Image(url=resp.text)

Now we'll see if the cache is actually working. We'll run cat_picture() twice in a row, and we should see the same picture. Then, we'll sleep for 10 seconds and we should get a new picture.

In [52]:
cat_picture()
hitting api

Out[52]:
In [53]:
cat_picture()
Out[53]:
In [54]:
time.sleep(10)
cat_picture()
hitting api

Out[54]:

Nice!

Let's say we wanted to make a more general cached decorator. We want to be able to specify how long to wait before the cache is invalidated.

This means we need one more layer of nesting. Basically, we need to write a function that returns the decorator. This is a bit complicated, so read this code carefully.

In [55]:
def cached(**kwargs):
    """return a decorator that caches for the specified time."""
    wait = timedelta(**kwargs)
    def cached_decorator(func):
        """Given a function of no arguments, return a version
        that caches its value for the duration <wait>."""
        cache = {}
        def cached_version():
            if 'expiry' in cache and cache['expiry'] > datetime.now():
                return cache['value']
            else:
                result = func()
                cache['value'] = result
                cache['expiry'] = datetime.now() + wait
                return result
        return cached_version
    return cached_decorator
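To check the expiry logic without hitting a network API, here is a sketch of the same decorator factory applied to a hypothetical fetch_counter that just reports how many times it really ran. It relies on timedelta accepting a milliseconds keyword, so the whole test takes a fraction of a second:

```python
import time
from datetime import datetime, timedelta

def cached(**kwargs):
    """Return a decorator that caches a no-argument function's result for timedelta(**kwargs)."""
    wait = timedelta(**kwargs)
    def cached_decorator(func):
        cache = {}
        def cached_version():
            if 'expiry' in cache and cache['expiry'] > datetime.now():
                return cache['value']
            result = func()
            cache['value'] = result
            cache['expiry'] = datetime.now() + wait
            return result
        return cached_version
    return cached_decorator

fetches = []

@cached(milliseconds=200)
def fetch_counter():
    """Stand-in for a slow API call: returns how many times it has really run."""
    fetches.append(datetime.now())
    return len(fetches)

first = fetch_counter()   # real call
second = fetch_counter()  # within 200 ms, served from the cache
time.sleep(0.3)
third = fetch_counter()   # cache expired, real call again
print('{} {} {}'.format(first, second, third))  # 1 1 2
```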

We can use this decorator like so:

In [56]:
@cached(seconds=2)
def cat_picture():
    print 'hitting api'
    resp = requests.get('http://edgecats.net/random')
    return Image(url=resp.text)

But hold on. This is starting to feel like magic again. Let's take apart that construction and see what it's really doing. We'll use our knowledge of what's really happening with that @ sign.

In [57]:
def cat_picture():
    print 'hitting api'
    resp = requests.get('http://edgecats.net/random')
    return Image(url=resp.text)
cat_picture = (cached(seconds=2))(cat_picture)

# Or, even clearer

cached_for_two_seconds = cached(seconds=2) # create our
# decorator, setting the wait time

def cat_picture():
    print 'hitting api'
    resp = requests.get('http://edgecats.net/random')
    return Image(url=resp.text)
cat_picture = cached_for_two_seconds(cat_picture)
In [58]:
cat_picture()
hitting api

Out[58]:
In [59]:
cat_picture()
Out[59]:
In [60]:
time.sleep(2)
cat_picture()
hitting api

Out[60]:

Imitating Flask

Flask is an excellent web framework that registers its views using decorators. We can mimic this interface to understand how it works.

In [61]:
class Tube(object):
    """A Tube is a small Flask, right?"""
    def __init__(self):
        self.endpoints = {}
    def route(self, url):
        def register_endpoint(func):
            self.endpoints[url] = func
            return func
        return register_endpoint
In [62]:
app = Tube()

@app.route('/api/cat_picture')
def get_cat_pic():
    return 'blah'

# or, desugared
route_to_dog_pic = app.route('/api/dog_picture')

def get_dog_pic():
    return 'slaw'
get_dog_pic = route_to_dog_pic(get_dog_pic)
In [63]:
app.endpoints
Out[63]:
{'/api/cat_picture': <function __main__.get_cat_pic>,
 '/api/dog_picture': <function __main__.get_dog_pic>}
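Registering is only half of the story; a framework also has to look the function up when a request arrives. Here is a sketch of that other half for Tube. The handle method and its '404 Not Found' string are hypothetical additions, not Flask's API, and real request dispatching in Flask involves much more:

```python
class Tube(object):
    """A Tube is a small Flask, right?"""
    def __init__(self):
        self.endpoints = {}

    def route(self, url):
        def register_endpoint(func):
            self.endpoints[url] = func
            return func  # the function itself is returned unchanged
        return register_endpoint

    def handle(self, url):
        """Hypothetical dispatcher: call whatever function was registered for url."""
        view = self.endpoints.get(url)
        if view is None:
            return '404 Not Found'
        return view()

app = Tube()

@app.route('/api/cat_picture')
def get_cat_pic():
    return 'blah'

print(app.handle('/api/cat_picture'))  # blah
print(app.handle('/api/unknown'))      # 404 Not Found
```

Because register_endpoint returns func unchanged, get_cat_pic is still an ordinary function you can call directly; the decorator's only job is the side effect of filling in the table.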

Bonus: Synchronization

Suppose you have a multithreaded program, and you want to protect a critical section with a mutex so no two threads can be in it at the same time. We can do this with a decorator!

In [64]:
import threading
def synchronized(func):
    mutex = threading.Lock()
    def critical_section(*args, **kwargs):
        # 'with' releases the lock even if func raises, and we pass the result back
        with mutex:
            return func(*args, **kwargs)
    return critical_section

We'll use an example from Lord of the Flies to illustrate what's going on. First, without synchronization.

In [65]:
#@synchronized  # Let's do this first without synchronization
def claim_speaking_priviledges(claimant):
    print "{}: I've got the conch!".format(claimant)
    time.sleep(1)
    print "{}: Okay, you can talk now.".format(claimant)
In [66]:
speakers = []
for speaker in ('Ralph', 'Piggy', 'Jack', 'Roger'):
    t = threading.Thread(target=claim_speaking_priviledges,
                         kwargs={'claimant':speaker})
    speakers.append(t)
    t.start()
for s in speakers:
    s.join()
Ralph: I've got the conch!
Piggy: I've got the conch!
Jack: I've got the conch!
Roger: I've got the conch!
Piggy: Okay, you can talk now.Roger: Okay, you can talk now.
 Jack: Okay, you can talk now.
Ralph: Okay, you can talk now.


In [67]:
@synchronized
def claim_speaking_priviledges(claimant):
    print "{}: I've got the conch!".format(claimant)
    time.sleep(1)
    print "{}: Okay, you can talk now.".format(claimant)
In [68]:
speakers = []
for speaker in ('Ralph', 'Piggy', 'Jack', 'Roger'):
    t = threading.Thread(target=claim_speaking_priviledges,
                         kwargs={'claimant':speaker})
    speakers.append(t)
    t.start()
for s in speakers:
    s.join()
Ralph: I've got the conch!
Ralph: Okay, you can talk now.
Piggy: I've got the conch!
Piggy: Okay, you can talk now.
Jack: I've got the conch!
Jack: Okay, you can talk now.
Roger: I've got the conch!
Roger: Okay, you can talk now.

Much more orderly.
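Printed output is one way to see the effect; a shared counter is another. In this sketch, a hypothetical increment does a deliberately slow read-modify-write. With the lock, twenty threads produce exactly 20. This version of synchronized uses threading.Lock in a with block, so the lock is released even if the function raises, and it passes the function's return value through:

```python
import threading
import time

def synchronized(func):
    """Allow only one thread at a time into func."""
    lock = threading.Lock()
    def critical_section(*args, **kwargs):
        with lock:
            return func(*args, **kwargs)
    return critical_section

state = {'count': 0}

@synchronized
def increment():
    current = state['count']
    time.sleep(0.001)  # widen the gap between reading and writing
    state['count'] = current + 1
    return state['count']

threads = [threading.Thread(target=increment) for _ in range(20)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(state['count'])  # 20
```

Each decorated function gets its own lock, created when the decorator runs. Two different functions decorated this way can still run at the same time as each other.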

Hopefully this illustrates the power of decorators and makes them a little less scary. I find that decorators can often be applied when you need

  • code that runs before and after each invocation of a function (like timed)
  • some metadata you want to store along with a function (like the lookup table in memoized)
  • registering a function in a table (like Flask and Tube's app.route)
  • permission-checking (like Django's login_required, which we didn't talk about)
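Permission checking combines two ideas from above: a decorator factory (like cached(seconds=2)) and code that runs before the real function (like timed). This is a sketch of the general idea, not Django's login_required; requires_permission, PermissionDenied, and the dict-based user are hypothetical:

```python
import functools

class PermissionDenied(Exception):
    """Raised when a user lacks a required permission (hypothetical)."""

def requires_permission(permission):
    """Return a decorator that only calls the view if user has `permission`."""
    def decorator(view):
        @functools.wraps(view)
        def checked_view(user, *args, **kwargs):
            # Runs before the real function; refuses the call if the check fails
            if permission not in user.get('permissions', ()):
                raise PermissionDenied('{} may not {}'.format(user['name'], permission))
            return view(user, *args, **kwargs)
        return checked_view
    return decorator

@requires_permission('delete_posts')
def delete_post(user, post_id):
    return '{} deleted post {}'.format(user['name'], post_id)

admin = {'name': 'Ralph', 'permissions': {'delete_posts'}}
guest = {'name': 'Jack', 'permissions': set()}

print(delete_post(admin, 7))  # Ralph deleted post 7
try:
    delete_post(guest, 7)
except PermissionDenied as err:
    print(err)                # Jack may not delete_posts
```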