Decorators look like magic when you first see them, and they don't feel like something mere mortals could write. We'll go through some examples of exactly how they work (without the scary '@' syntax), and why they are a nice pattern for writing elegant, easy-to-understand code.
The following is adapted from a talk I gave at STL Python on January 6, 2015.
We need some code to work with, so let's go with something familiar: the recursive definition of the Fibonacci numbers, which most coders have seen before.
def fib(n):
    """Give the nth fibonacci number.
    This code runs in O(2^n) time, for illustration purposes"""
    if 0 <= n <= 1:
        return n
    return fib(n - 1) + fib(n - 2)
Let's try out a few test cases, to see how the performance is.
for ix in xrange(10):
    print ix, fib(ix)
%time fib(20)
%time fib(30)
%time fib(35)
Five seconds and we're only at n=35! This is not going to cut it. The reason (as you may already know) is that we're repeating a lot of work. For example, fib(5) calls fib(4) and fib(3), but because we don't save fib(3), we have to compute it again when fib(4) needs it. This gives us a runtime of O(2^n), which is very slow.
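To make the repeated work concrete, here is a small sketch that tallies how often each argument is requested. The names counted_fib and calls are illustrative additions, not part of the original talk, and the snippet uses print() with parentheses so it also runs on Python 3.

```python
from collections import Counter

calls = Counter()  # how many times each n was requested

def counted_fib(n):
    """Same recursion as fib, but tallies every call by argument."""
    calls[n] += 1
    if 0 <= n <= 1:
        return n
    return counted_fib(n - 1) + counted_fib(n - 2)

result = counted_fib(10)
total_calls = sum(calls.values())

print(result)       # 55
print(total_calls)  # 177 calls just to reach n=10
print(calls[3])     # 21: fib(3) was recomputed 21 times
```

Every one of those 21 computations of fib(3) returns the same answer, which is exactly the waste that saving results removes.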
Rather than re-compute all that work, let's save the results. fib(3) is going to be 2 no matter who asks, so we can save the number each time. This is called 'memoization'.
fib_table = {}  # somewhere to store the answers we've already computed

def memo_fib(n):
    """Give the nth fibonacci number.
    This code is much faster than before."""
    if n in fib_table:
        return fib_table[n]
    if 0 <= n <= 1:
        return n
    result = memo_fib(n - 1) + memo_fib(n - 2)
    fib_table[n] = result  # make sure to save this for the next call
    return result
Let's see how the runtime compares with our unmemoized version.
%time memo_fib(20)
%time memo_fib(30)
%time memo_fib(35)
%time memo_fib(350)
%time memo_fib(1000)
Whoa, what a change! In fact, memoization brings the runtime down to O(n), much better than O(2^n).
We might want to use memoization on other occasions. Decorators can help us with this.
What is the memoization pattern? Basically, you keep a dictionary of the results of each call. When a new call comes in, first check whether its answer is already in the table of saved results. If it isn't and we do have to do the work, save the answer for next time.
Here's the punchline: we can factor out this idea, and make a memoized higher-order function.
def memoized(func):
    """Given a function of one argument,
    return a memoized version of that function.
    """
    lookup_table = {}
    def memoized_version(arg):
        if arg in lookup_table:
            return lookup_table[arg]
        result = func(arg)
        lookup_table[arg] = result
        return result
    return memoized_version
Okay, let's test this out on a function.
import time, sys

def greet(name):
    """Print out a personalized greeting.
    (personalization takes time).
    """
    print 'hold on, let me think...'
    time.sleep(2)
    return 'Hi there, {}'.format(name)
%time greet('Brian')
%time greet('Dan')
memoized_greet = memoized(greet)
%time memoized_greet('Dan') # The first call is still slow
%time memoized_greet('Dan') # Subsequent calls are quick!
So memoization is great, but I don't really want to have two versions of my function around. Let's overwrite the original and just always use memoization.
greet = memoized(greet)
%time greet('Dan')
%time greet('Dan')
Remember our slow O(2^n) implementation of fibonacci from above? We can memoize it with no code changes.
fib = memoized(fib)
fib(50)
fib(200)
fib(500)
fib(800)
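It may not be obvious why the recursive calls inside fib also get memoized when we never touched its body. The sketch below, which adds a hypothetical body_runs list for bookkeeping, suggests the reason: the name fib is looked up each time the body runs, so once it is rebound, the inner calls go through the memoized version too.

```python
def memoized(func):
    lookup_table = {}
    def memoized_version(arg):
        if arg in lookup_table:
            return lookup_table[arg]
        result = func(arg)
        lookup_table[arg] = result
        return result
    return memoized_version

body_runs = []  # every n for which the original body actually ran

def fib(n):
    body_runs.append(n)
    if 0 <= n <= 1:
        return n
    # `fib` is resolved by name at call time, so after the rebinding
    # below these calls hit the memoized version.
    return fib(n - 1) + fib(n - 2)

fib = memoized(fib)

answer = fib(30)
print(answer)          # 832040
print(len(body_runs))  # 31: the body ran once for each n from 0 to 30
```

Had we written slow_fib = memoized(fib) under a different name instead, only the outermost call would have been cached and the recursion would still be exponential.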
And actually, that's decorators. That's exactly what decorators do. They overwrite a function with a modified version of the same function. So we could have written greet this way:
@memoized
def greet(name):
    """Print out a personalized greeting.
    (personalization takes time)."""
    print 'hold on, let me think...'
    time.sleep(2)
    return 'Hi there, {}'.format(name)

# Exactly the same as...
def greet(name):
    """Print out a personalized greeting.
    (personalization takes time)."""
    print 'hold on, let me think...'
    time.sleep(2)
    return 'Hi there, {}'.format(name)
greet = memoized(greet)

@memoized
def fib(n):
    """Give the nth fibonacci number.
    Because of memoization, this code runs in O(n) time
    (but it still might blow your stack)"""
    if 0 <= n <= 1:
        return n
    return fib(n - 1) + fib(n - 2)

# Exactly the same as...
def fib(n):
    """Give the nth fibonacci number.
    Because of memoization, this code runs in O(n) time
    (but it still might blow your stack)"""
    if 0 <= n <= 1:
        return n
    return fib(n - 1) + fib(n - 2)
fib = memoized(fib)
Okay, that's cool, but let's see another example. Say we wanted to time the execution of a function. Decorators allow us to write code that runs before and after the regular function, so this shouldn't be too bad. We'll need some code to time, so let's use a prime checker.
#@timed # This doesn't exist yet...
def is_prime(n):
    potential_divisor = 2
    while potential_divisor ** 2 <= n:
        if n % potential_divisor == 0:
            return False
        potential_divisor += 1
    return True
is_prime(618)
from datetime import datetime

def timed(func):
    def timed_version(n):
        start = datetime.now()
        result = func(n)
        print 'that took', datetime.now() - start
        return result
    return timed_version
Now that we've written a definition for timed, we can use it in our definition of is_prime.
@timed
def is_prime(n):
    potential_divisor = 2
    while potential_divisor ** 2 <= n:
        if n % potential_divisor == 0:
            return False
        potential_divisor += 1
    return True
is_prime(2**31 - 1)
So that's great, but we swept a small detail under the rug. This will only work for functions of a single parameter (called n)! When we wrote our timed_version, we assumed it would accept n as its only parameter.
This seems like a violation of reuse: our timing decorator should be general. It shouldn't work only for functions with one parameter, let alone functions with one parameter called n.
Suppose we wanted to change the parameters to is_prime, maybe also passing in give_factors, like this:
@timed
def is_prime(num, give_factors):
    """Determine whether a number is prime.
    num - the possibly prime number
    give_factors - True/False, whether to include the factors
    """
    potential_divisor = 2
    factors = []
    while potential_divisor ** 2 <= num:
        if num % potential_divisor == 0:
            factors.append((potential_divisor, num / potential_divisor))
        potential_divisor += 1
    num_is_prime = len(factors) == 0
    if give_factors:
        return num_is_prime, factors
    else:
        return num_is_prime
is_prime(2**31 - 1, True) # This will crash because
# timed_version is expecting only a single parameter
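To see the failure without waiting on a real primality check, here is a stripped-down sketch. The name timed_single and the placeholder body of is_prime are stand-ins introduced for this illustration, and the timing code is left out.

```python
def timed_single(func):
    """Like the first version of timed: the wrapper takes exactly one argument."""
    def timed_version(n):
        return func(n)  # timing code omitted to keep the sketch short
    return timed_version

@timed_single
def is_prime(num, give_factors):
    return True  # placeholder body; the call below never gets this far

try:
    is_prime(2**31 - 1, True)
    outcome = 'no error'
except TypeError as err:
    # The wrapper, not is_prime, rejects the second argument.
    outcome = 'TypeError'
    print(err)
```

Note that the error comes from timed_version, because after decoration that wrapper is what the name is_prime refers to.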
We'll rewrite this in a moment to be more general, but first, we have to talk about another confusing aspect of Python: star-args.
*args, **kwargs
*args and **kwargs use the packing/unpacking operators. A single * is used to turn a list to/from a series of positional parameters. A double ** does the same, but for dictionaries and keyword arguments.
Let's do an example.
def add(a, b):
    return a + b
operands = [3, 4]
add(*operands) # --> add(operands[0], operands[1])
The *operands gets turned into a series of positional parameters, one for each item in operands.
What if we use the * in the function definition?
def add(*args):
    total = args[0]
    for addend in args[1:]:
        total += addend
    return total
This new version of add accepts any number of parameters (at least one, since it starts from args[0]). They are all packed into the args variable.
add(1, 2, 3, 4, 5)
Suppose we wanted to use both packing and unpacking? We can! *operands here will be unpacked to make a positional argument for each element of the list operands. Then, in the body of the add function, the positional arguments have all been packed into a single variable, args.
operands = [1, 2, 3, 4, 5]
add(*operands) # --> add(operands[0], operands[1], ..., operands[4])
So that's packing and unpacking for positional parameters. What about keyword parameters? We'll use the following function for illustration.
def birthday(name, age):
    return 'Happy Birthday, {}! congrats on number {}.'.format(name, age)
params = {'name': 'Batman', 'age': 31}
print birthday(**params) # --> birthday(name='Batman', age=31)
The params dictionary gets unpacked into keyword args and passed into birthday.
This can be useful if we have a whole list of folks we want to wish a Happy Birthday:
birthday_list = [
    {'name': 'Robert', 'age': 12},
    {'name': 'Steven', 'age': 26},
    {'name': 'Cassie', 'age': 39}
]
for bday in birthday_list:
    print birthday(**bday) # --> birthday(name='Robert', age=12)
Just like with the positional parameters, there is a packing operator as well. In this case, we'll illustrate with a reimplementation of the new-style string format function.
def strformat(template, **kwargs):
    for key, val in kwargs.items():
        # str() lets non-string values (like numbers) be substituted too
        template = template.replace('{' + key + '}', str(val))
    return template
Any keyword parameters that are passed into the strformat function get packaged up into the kwargs dictionary. So strformat doesn't even need to know the names of the parameters you're going to use!
strformat('Happy {occasion}, {name}',
          occasion='Thanksgiving', name='Diane')
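Both kinds of packing can appear in the same function. The following sketch uses a made-up function, describe_call, that simply hands back whatever it was given, so we can inspect what ends up in args and kwargs.

```python
def describe_call(*args, **kwargs):
    """Return the packed positional and keyword arguments as-is."""
    return args, kwargs

positional, keyword = describe_call(1, 2, name='Batman', age=31)
print(positional)  # (1, 2), packed as a tuple rather than a list
print(keyword)     # a dict with the keys 'name' and 'age'

# Unpacking a list and a dict at the call site gives the same result.
same = describe_call(*[1, 2], **{'name': 'Batman', 'age': 31})
```

One detail worth noticing: the positional arguments arrive as a tuple, even when they were unpacked from a list.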
Okay, excellent. So, *args, **kwargs let us work with functions that are really general.
This helps us with writing decorators because we want to be able to accept any set of parameters, without knowing or specifying them in advance. Let's revisit our timed decorator as an example.
def timed(func):
    def timed_version(*args, **kwargs):
        start = datetime.now()
        result = func(*args, **kwargs)
        print 'that took', datetime.now() - start
        return result
    return timed_version
Excellent. None of this "you must use a function with one parameter and it must be called n" business.
@timed
def is_prime(num, give_factors):
    """Determine whether a number is prime.
    num - the possibly prime number
    give_factors - True/False, whether to include the factors
    """
    potential_divisor = 2
    factors = []
    while potential_divisor ** 2 <= num:
        if num % potential_divisor == 0:
            factors.append((potential_divisor, num / potential_divisor))
        potential_divisor += 1
    num_is_prime = len(factors) == 0
    if give_factors:
        return num_is_prime, factors
    else:
        return num_is_prime
is_prime(2**31 - 1, True) # This time it actually works!
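The same trick can lift the one-argument restriction on memoized. What follows is one possible sketch, not the version from the talk: it assumes every argument is hashable, and the power function and call_log list exist only to demonstrate it.

```python
def memoized(func):
    """Memoize a function taking any number of hashable arguments."""
    lookup_table = {}
    def memoized_version(*args, **kwargs):
        # Build a hashable key; sorting makes keyword order irrelevant.
        key = (args, tuple(sorted(kwargs.items())))
        if key not in lookup_table:
            lookup_table[key] = func(*args, **kwargs)
        return lookup_table[key]
    return memoized_version

call_log = []  # records each time the real function runs

@memoized
def power(base, exponent=2):
    call_log.append((base, exponent))
    return base ** exponent

first = power(3, exponent=4)   # computed
second = power(3, exponent=4)  # served from the table
third = power(3, 4)            # same values, but a different key
```

A limitation of this key scheme is visible in the last call: power(3, 4) and power(3, exponent=4) are stored separately, so the function runs twice even though the answers are identical.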
More examples
We'll make a cached decorator that remembers the result it computed, but only for a short time. We could do our fancy *args, **kwargs business, and make kind of a cached_and_memoized decorator, but let's keep it simple.
from datetime import datetime, timedelta

def cached(func):
    """Given a function of no arguments, return a version
    that caches its value for ten seconds."""
    cache = {}
    def cached_version():
        if 'expiry' in cache and cache['expiry'] > datetime.now():
            return cache['value']
        else:
            result = func()
            cache['value'] = result
            cache['expiry'] = datetime.now() + timedelta(seconds=10)
            return result
    return cached_version
Now to use the decorator:
from IPython.display import Image
import requests

@cached
def cat_picture():
    print 'hitting api'
    resp = requests.get('http://edgecats.net/random')
    return Image(url=resp.text)
Now we'll see if the cache is actually working. We'll run cat_picture() twice in a row, and we should see the same picture. Then, we'll sleep for 10 seconds and we should get a new picture.
cat_picture()
cat_picture()
time.sleep(10)
cat_picture()
Nice!
Let's say we wanted to make a more general cached decorator. We want to be able to specify how long to wait before the cache is invalid.
This means we need one more layer of nesting. Basically, we need to write a function that returns the decorator. This is a bit complicated, so read this code carefully.
def cached(**kwargs):
    """return a decorator that caches for the specified time."""
    wait = timedelta(**kwargs)
    def cached_decorator(func):
        """Given a function of no arguments, return a version
        that caches its value for <wait> seconds."""
        cache = {}
        def cached_version():
            if 'expiry' in cache and cache['expiry'] > datetime.now():
                return cache['value']
            else:
                result = func()
                cache['value'] = result
                cache['expiry'] = datetime.now() + wait
                return result
        return cached_version
    return cached_decorator
We can use this decorator like so:
@cached(seconds=2)
def cat_picture():
    print 'hitting api'
    resp = requests.get('http://edgecats.net/random')
    return Image(url=resp.text)
But hold on. This is starting to feel like magic again. Let's take apart that construction and see what it's really doing. We'll use our knowledge of what's really happening with that @ sign.
def cat_picture():
    print 'hitting api'
    resp = requests.get('http://edgecats.net/random')
    return Image(url=resp.text)
cat_picture = (cached(seconds=2))(cat_picture)

# Or, even clearer
cached_for_two_seconds = cached(seconds=2) # create our
                                           # decorator, setting the wait time
def cat_picture():
    print 'hitting api'
    resp = requests.get('http://edgecats.net/random')
    return Image(url=resp.text)
cat_picture = cached_for_two_seconds(cat_picture)
cat_picture()
cat_picture()
time.sleep(2)
cat_picture()
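If you'd like to check the expiry behavior without a network connection or IPython, here is a self-contained sketch. It repeats the cached factory from above and swaps the cat API for a hypothetical next_number function that counts its own real invocations; the 200 millisecond wait is an arbitrary choice to keep the demo quick.

```python
import time
from datetime import datetime, timedelta

def cached(**kwargs):
    """Return a decorator that caches for the specified time."""
    wait = timedelta(**kwargs)
    def cached_decorator(func):
        cache = {}
        def cached_version():
            if 'expiry' in cache and cache['expiry'] > datetime.now():
                return cache['value']
            result = func()
            cache['value'] = result
            cache['expiry'] = datetime.now() + wait
            return result
        return cached_version
    return cached_decorator

counter = {'calls': 0}  # how many times the real function has run

@cached(milliseconds=200)
def next_number():
    counter['calls'] += 1
    return counter['calls']

a = next_number()  # real call
b = next_number()  # within the wait, so served from the cache
time.sleep(0.4)    # let the cache expire
c = next_number()  # real call again
```

Because the keyword arguments are passed straight to timedelta, anything timedelta accepts (seconds, minutes, milliseconds, and so on) works as the cache duration.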
Imitating Flask
Flask is an excellent web framework that registers its views using decorators. We can mimic this interface to understand how it works.
class Tube(object):
    """A Tube is a small Flask, right?"""
    def __init__(self):
        self.endpoints = {}

    def route(self, url):
        def register_endpoint(func):
            self.endpoints[url] = func
            return func
        return register_endpoint
app = Tube()
@app.route('/api/cat_picture')
def get_cat_pic():
    return 'blah'

# or, desugared
route_to_dog_pic = app.route('/api/dog_picture')
def get_dog_pic():
    return 'slaw'
get_dog_pic = route_to_dog_pic(get_dog_pic)
app.endpoints
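Once the views are in the table, serving a request is just a dictionary lookup followed by a call. The sketch below adds a hypothetical handle method to Tube to show the idea; it is not how Flask itself dispatches requests, and the '404 not found' string is a placeholder.

```python
class Tube(object):
    """A Tube is a small Flask, with a toy dispatcher added."""
    def __init__(self):
        self.endpoints = {}

    def route(self, url):
        def register_endpoint(func):
            self.endpoints[url] = func
            return func  # hand the function back unchanged
        return register_endpoint

    def handle(self, url):
        """Hypothetical dispatcher: call whichever view was registered."""
        view = self.endpoints.get(url)
        if view is None:
            return '404 not found'
        return view()

app = Tube()

@app.route('/api/cat_picture')
def get_cat_pic():
    return 'blah'

response = app.handle('/api/cat_picture')
missing = app.handle('/api/nothing_here')
```

Unlike memoized or timed, this decorator returns the original function untouched. Its only job is the side effect of recording the function under a URL.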
Bonus: Synchronization
Suppose you have a multithreaded program, and you want to protect the critical section with a mutex so no two threads can be there at the same time. We can do this with a decorator!
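import threading

def synchronized(func):
    mutex = threading.Semaphore()  # initial value 1, so it acts as a mutex
    def critical_section(*args, **kwargs):
        # `with` releases the mutex even if func raises,
        # and we pass along whatever func returns
        with mutex:
            return func(*args, **kwargs)
    return critical_section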
We'll use an example from Lord of the Flies to illustrate what's going on. First, without synchronization.
#@synchronized  # Let's do this first without
def claim_speaking_priviledges(claimant):
    print "{}: I've got the conch!".format(claimant)
    time.sleep(1)
    print "{}: Okay, you can talk now.".format(claimant)

speakers = []
for speaker in ('Ralph', 'Piggy', 'Jack', 'Roger'):
    t = threading.Thread(target=claim_speaking_priviledges,
                         kwargs={'claimant': speaker})
    speakers.append(t)
    t.start()
for s in speakers:
    s.join()
@synchronized
def claim_speaking_priviledges(claimant):
    print "{}: I've got the conch!".format(claimant)
    time.sleep(1)
    print "{}: Okay, you can talk now.".format(claimant)

speakers = []
for speaker in ('Ralph', 'Piggy', 'Jack', 'Roger'):
    t = threading.Thread(target=claim_speaking_priviledges,
                         kwargs={'claimant': speaker})
    speakers.append(t)
    t.start()
for s in speakers:
    s.join()
Much more orderly.
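Eyeballing the printed output is one way to check this; another is to have the threads record how many of them are inside the critical section at once. The sketch below does that with two bookkeeping lists, inside and max_inside, which are illustrative additions. It also uses a with block so the mutex is released even if the wrapped function raises.

```python
import threading
import time

def synchronized(func):
    mutex = threading.Semaphore()  # initial value 1, so it acts as a mutex
    def critical_section(*args, **kwargs):
        with mutex:
            return func(*args, **kwargs)
    return critical_section

inside = []       # claimants currently holding the conch
max_inside = [0]  # the most claimants ever seen inside at once

@synchronized
def take_conch(claimant):
    inside.append(claimant)
    max_inside[0] = max(max_inside[0], len(inside))
    time.sleep(0.05)  # pretend to talk for a moment
    inside.remove(claimant)

threads = [threading.Thread(target=take_conch, args=(name,))
           for name in ('Ralph', 'Piggy', 'Jack', 'Roger')]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(max_inside[0])  # 1: only one speaker at a time
```

Removing the @synchronized line lets the threads overlap, and max_inside[0] will typically climb above 1.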
Hopefully this illustrates the power of decorators, and lessens some of their scariness. I find that decorators can often be applied when you need
- code that runs before and after each invocation of a function (like timed)
- some metadata you want to store along with a function (like the lookup table in memoized)
- registering a function in a table (like Flask and Tube's app.route)
- permission-checking (like Django's login_required, which we didn't talk about)
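Since permission-checking is the one item on that list without an example, here is a rough sketch of the general shape. To be clear, this is not Django's login_required: the requires_role factory, the NotAllowed exception, and the dictionary-based users are all invented for illustration.

```python
class NotAllowed(Exception):
    """Raised when a user lacks the required role (hypothetical)."""

def requires_role(role):
    """Return a decorator that only lets users holding `role` through."""
    def decorator(func):
        def checked_version(user, *args, **kwargs):
            # Runs before the real function, just like the timing code did.
            if role not in user.get('roles', ()):
                raise NotAllowed('{} needs the {} role'.format(user['name'], role))
            return func(user, *args, **kwargs)
        return checked_version
    return decorator

@requires_role('admin')
def delete_everything(user):
    return 'deleted by {}'.format(user['name'])

admin = {'name': 'Ralph', 'roles': ['admin']}
guest = {'name': 'Jack', 'roles': []}

allowed = delete_everything(admin)
try:
    delete_everything(guest)
    denied = False
except NotAllowed:
    denied = True
```

It combines two ideas from earlier: the extra layer of nesting that lets the decorator take a parameter, and *args, **kwargs so the check works for any function whose first argument is the user.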