From Wikipedia:

As used in some Lisp implementations, a trampoline is a loop that iteratively invokes thunk-returning functions (continuation-passing style). A single trampoline suffices to express all control transfers of a program; a program so expressed is trampolined, or in trampolined style; converting a program to trampolined style is trampolining. Programmers can use trampolined functions to implement tail-recursive function calls in stack-oriented programming languages.

In this post I explain what a trampoline is and why it matters.

Consider this recursive Python function:

def tail_recursive_f(n):
    if n == 0:
        return "done"
    print(n)
    # The recursive call is the last thing the function does (a tail call)
    return tail_recursive_f(n-1)

The argument n gives the number of recursive calls performed by tail_recursive_f. Can you invoke tail_recursive_f with any value of n? In Python, and in the stack-oriented programming languages Wikipedia talks about, the answer is no. Each call pushes a new frame, holding the function's parameters and return address, onto the call stack, which is a limited resource. Python does not perform tail-call elimination, so this happens even though the recursive call is the last thing tail_recursive_f does. Thus, big values of n break the stack.

>>> tail_recursive_f(10000)
10000
...
RuntimeError: maximum recursion depth exceeded

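Python has a safety guard: if the recursion depth exceeds a limit (1000 by default in CPython, adjustable with sys.setrecursionlimit()), it raises a RuntimeError exception (in Python 3.5 and later, its more specific subclass RecursionError).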
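A minimal sketch for inspecting the guard in Python 3. The helpers countdown and hits_recursion_limit are hypothetical names; countdown has the same shape as tail_recursive_f but drops the print so the output stays short.

```python
import sys

def countdown(n):
    """Same shape as tail_recursive_f, minus the print (hypothetical helper)."""
    if n == 0:
        return "done"
    return countdown(n - 1)

def hits_recursion_limit(n):
    """Return True if countdown(n) trips Python's recursion guard."""
    try:
        countdown(n)
    except RecursionError:  # Python 3.5+, a subclass of RuntimeError
        return True
    return False

print(sys.getrecursionlimit())      # 1000 by default on CPython
print(hits_recursion_limit(100))    # well below the default limit
print(hits_recursion_limit(10000))  # far above the default limit
```

With the default limit, the first call completes and the second one is stopped by the guard long before the C stack is in danger.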

What if you want to stop caring about the stack? If your function is tail-recursive, like tail_recursive_f, you can write a trampolined version of it.

The objective is to 1) store the recursive calls somewhere for future execution, 2) extract them from the function, and 3) run them in a loop.

To postpone a computation, we usually wrap it in a function that performs that computation when invoked. Such a function is often called a thunk.

def operation():
    return 6*10

def lazy_operation():
    return lambda: 6*10

>>> operation()
60
>>> lazy_operation()
<function <lambda> at 0x7ffa2753c7d0>
>>> thunk = lazy_operation()
>>> thunk()
60
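Any zero-argument callable works as a thunk. A small sketch comparing a lambda with functools.partial from the standard library; slow_square and calls are illustrative names, used only to record when the real work happens.

```python
from functools import partial

calls = []  # records every time the real computation runs

def slow_square(x):
    calls.append(x)
    return x * x

thunk_a = lambda: slow_square(7)   # thunk built with a lambda
thunk_b = partial(slow_square, 7)  # same thunk built with functools.partial

print(calls)      # [] : creating the thunks computed nothing
print(thunk_a())  # 49
print(thunk_b())  # 49
print(calls)      # [7, 7]
```

One difference worth knowing: a lambda that refers to a variable looks it up when it is called (late binding), whereas partial stores the argument values when it is created.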

We transform tail_recursive_f in the same way, wrapping the recursive call in a lambda:

def tail_recursive_f(n):
    if n == 0:
        return "done"
    print(n)
    # Instead of recursing now, return a thunk that makes the next call later
    return lambda: tail_recursive_f(n-1)

This version of tail_recursive_f runs one step, then returns a thunk holding the rest of the computation (the continuation).

>>> tail_recursive_f(10000)
10000
<function <lambda> at 0x7ffa2753c758>

We can store it in a variable, and call it to execute the next step and get the next continuation.

>>> thunk = tail_recursive_f(10000)
10000
>>> thunk()
9999
<function <lambda> at 0x7ffa2753c578>

and so on.

>>> thunk = thunk()
9999
>>> thunk = thunk()
9998

At the end, the last thunk calls tail_recursive_f(0), which returns the string "done" instead of a callable.
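Driving the chain by hand for a small n makes the sequence concrete. This sketch assumes a hypothetical variant, countdown_steps, that appends to a list instead of printing so the steps can be checked.

```python
seen = []  # values of n visited by the chain

def countdown_steps(n):
    """Trampolined countdown that records n instead of printing it."""
    if n == 0:
        return "done"
    seen.append(n)
    return lambda: countdown_steps(n - 1)

result = countdown_steps(3)  # runs one step, returns a thunk
while callable(result):      # keep calling thunks until a non-callable appears
    result = result()

print(seen)    # [3, 2, 1]
print(result)  # done
```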

The trampoline is the loop that runs the chain of thunks until one of them returns something that is not callable:

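def trampoline(f, *args, **kwargs):
    # Wrap the first call in a thunk so every step looks the same
    g = lambda: f(*args, **kwargs)
    # Keep bouncing while the result is another thunk
    while callable(g):
        g = g()
    # g is now the final, non-callable result
    return g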

The function takes as input the trampolined function and its arguments, so that it can call it. First, it wraps the initial call in a new function, which can be called without parameters. Then it keeps calling the thunks as long as they return a thunk, i.e. a callable. When a call finally returns a non-callable value, that is the actual result, and the trampoline returns it.
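One limitation follows from the callable() test: a trampolined function cannot return a function as its final result, because the trampoline would try to call it. A common workaround, sketched here with a hypothetical Thunk class (not part of the standard library), marks continuations explicitly and checks isinstance instead of callable.

```python
class Thunk:
    """Explicit marker for a postponed call (hypothetical helper)."""
    def __init__(self, func, *args, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __call__(self):
        return self.func(*self.args, **self.kwargs)

def trampoline_explicit(f, *args, **kwargs):
    result = f(*args, **kwargs)
    # Only Thunk instances are bounced; any other value, even a function, is the result
    while isinstance(result, Thunk):
        result = result()
    return result

def make_adder_after(n, k):
    """Counts n down to 0, then returns a function that adds k (illustrative)."""
    if n == 0:
        return lambda x: x + k
    return Thunk(make_adder_after, n - 1, k)

add_five = trampoline_explicit(make_adder_after, 10000, 5)
print(add_five(1))  # 6
```

With the callable()-based trampoline, the final lambda would be called with no arguments and raise a TypeError; the explicit marker avoids that at the cost of wrapping every continuation.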

Now tail_recursive_f(10000) is invoked like this:

>>> trampoline(tail_recursive_f, 10000)

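which prints all the integers from 10000 down to 1 before returning "done", without hitting the recursion limit: each thunk returns to the loop before the next one is called, so the stack never grows.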

What about a function that is not tail-recursive? Take the simplest recursive function of all: the factorial.

def factorial(n):
    if n == 0:
        return 1
    return n*factorial(n-1)

>>> factorial(3)
6
>>> factorial(1000)
...
RuntimeError: maximum recursion depth exceeded

This version of factorial is not tail-recursive: at each recursion level, the caller must wait for the callee to return, because it still has to multiply the returned value by n.

So first of all we make it tail-recursive. One technique is to carry the intermediate result from call to call in an extra function argument, an accumulator, instead of leaving pending multiplications on the stack.

def tail_recursive_factorial(n, acc=1):
    if n == 0:
        return acc
    return tail_recursive_factorial(n-1, acc=n*acc)

In this way, each call already has all the information needed to compute the result: nothing is left to do after the recursive call, so there is no need to return to the caller.
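Because nothing happens after the recursive call, each step just replaces (n, acc) with (n-1, n*acc). A sketch (factorial_states is a hypothetical name) that performs the same update in a plain loop and records every state, making the accumulator visible:

```python
def factorial_states(n, acc=1):
    """Loop version of tail_recursive_factorial, returning the visited (n, acc) pairs."""
    states = [(n, acc)]
    while n != 0:
        n, acc = n - 1, n * acc  # the same update the tail call performs
        states.append((n, acc))
    return states

print(factorial_states(4))
# [(4, 1), (3, 4), (2, 12), (1, 24), (0, 24)]
```

The final accumulator, 24, is 4!. Rewriting the tail call as a loop by hand gives constant stack usage directly; the trampoline achieves the same without restructuring the function.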

Then we apply the same transformation as before:

def tail_recursive_factorial(n, acc=1):
    if n == 0:
        return acc
    return lambda: tail_recursive_factorial(n-1, acc=n*acc)

and we are ready for our trampoline:

>>> trampoline(tail_recursive_factorial, 3)
6
>>> trampoline(tail_recursive_factorial, 1000)
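
A self-contained sketch that repeats the definitions above in Python 3 syntax and checks the trampolined result against math.factorial from the standard library. With the default recursion limit of 1000, the plain recursive factorial(1000) fails, while this call completes.

```python
import math

def trampoline(f, *args, **kwargs):
    g = lambda: f(*args, **kwargs)
    while callable(g):
        g = g()
    return g

def tail_recursive_factorial(n, acc=1):
    if n == 0:
        return acc
    # Postpone the tail call; the accumulator already holds the partial product
    return lambda: tail_recursive_factorial(n - 1, acc=n * acc)

result = trampoline(tail_recursive_factorial, 1000)
print(result == math.factorial(1000))  # True
```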
