Tail recursion in Python

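So I had a very recursive problem that needed to be solved in Python. Now, as we know, Python does not do tail call optimization: every call, even one in tail position, gets its own stack frame. So if your problem recurses a wee bit too deeply, you run out of stack space (CPython raises a RecursionError once you hit the recursion limit).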
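To make the problem concrete, here's a minimal sketch (count_down is a made-up example, not part of the hack). The recursive call is the last thing the function does, yet CPython still pushes a new frame for it, so going past sys.getrecursionlimit() raises RecursionError:

```python
import sys

def count_down(n):
    # The recursive call is in tail position, but CPython still pushes a
    # new frame for it: there is no tail call elimination.
    if n == 0:
        return "done"
    return count_down(n - 1)

limit = sys.getrecursionlimit()  # CPython's default is 1000
print(count_down(limit // 2))     # well under the limit, prints "done"

try:
    count_down(limit + 100)       # deeper than the limit allows
    overflowed = False
except RecursionError:
    overflowed = True
print("overflowed:", overflowed)
```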

Luckily, it's quite easy to hack up some kind of workaround, so this is exactly what I did. I'm sure there are cleaner / better solutions out there, but half the fun is doing it yourself...

My solution is based on a decorator and a special function that tells "the system" that what we call is indeed a tail call. There's no auto-detection of tail calls, so you have to tell the system yourself. Maybe I'll extend the solution to do auto-detection, but that will require fiddling with Python bytecode, which I've never really done before. We'll see.
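Since tail calls have to be marked by hand, it helps to be clear on what counts as one: a call whose result is returned as-is, with nothing left to do in the caller afterwards. Here's a hypothetical factorial written both ways (plain Python, no hack involved):

```python
def factorial(n):
    # NOT a tail call: the multiplication happens after the recursive
    # call returns, so this frame has to stay alive until then.
    if n <= 1:
        return 1
    return n * factorial(n - 1)


def factorial_acc(n, acc=1):
    # Tail call: the recursive call's result is returned unchanged, so
    # nothing in this frame is needed afterwards. The pending work is
    # carried along in the accumulator `acc` instead.
    if n <= 1:
        return acc
    return factorial_acc(n - 1, acc * n)


print(factorial(10), factorial_acc(10))  # 3628800 3628800
```

Only the second version could have its recursive call wrapped in tail_call; in the first, the multiplication still needs the caller's frame.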

Alright, so here's what a fully tail-recursive Fibonacci function looks like with my hack:

@tail_callable
def fib(a, b):
    print(a)
    return tail_call(fib, b, a + b)

You can see the decorator, and the call to our helper function in place of the direct recursive call. You can run this function forever, and it will never run out of stack frames!

And here's how it's implemented. Quite simply, we define our own exception type (it inherits from BaseException, which user code should NEVER catch, so an ordinary "except Exception" clause in between won't swallow it), a decorator, and our "call wrapper":

from functools import wraps


class TailCall(BaseException):
    def __init__(self, fn, args, kwargs):
        self.fn     = fn
        self.args   = args
        self.kwargs = kwargs


def tail_call(fn, *args, **kwargs):
    """Call this instead of directly recursing"""
    raise TailCall(fn, args, kwargs)


def tail_callable(fn):
    """Decorator to make a function tail-callable."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        # Note: it's tempting to try to detect tail calls by analyzing the
        # stack frame. However you cannot decide if something is just a call
        # from one tail_callable to the next, or if it's a "true" tail call.
        # We'd need to analyze the bytecode of the caller for that.

        next_fun = fn
        while True:
            try:
                # If we return (or have a non-TailCall exception), there's no
                # more recursion or tail calling
                return next_fun(*args, **kwargs)
            except TailCall as call:
                args     = call.args
                kwargs   = call.kwargs
                next_fun = call.fn
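                # If the target is itself tail_callable, call the undecorated
                # function: going through its wrapper again would nest another
                # loop (and more stack frames) on every tail call.
                if hasattr(next_fun, 'unwrapped'):
                    next_fun = next_fun.unwrapped

    wrapper.unwrapped = fn

    return wrapper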
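To see why raising an exception keeps the stack flat, here's a stripped-down sketch of the same mechanism (my own illustration, not the code above): a hypothetical Bounce exception plays the role of TailCall, and run() is the loop inside wrapper. It records the stack depth on every "call" using sys._getframe, which is CPython-specific:

```python
import sys


class Bounce(BaseException):
    """Hypothetical stand-in for TailCall: carries the next argument."""
    def __init__(self, n):
        super().__init__(n)
        self.n = n


def stack_depth():
    # Count the frames below our caller (sys._getframe is CPython-specific).
    frame, depth = sys._getframe(1), 0
    while frame is not None:
        depth += 1
        frame = frame.f_back
    return depth


depths = set()


def step(n):
    depths.add(stack_depth())  # record how deep we are on every "call"
    if n == 0:
        return "done"
    raise Bounce(n - 1)        # instead of `return step(n - 1)`


def run(n):
    # The driver loop: raising unwinds step()'s frame before we call it again.
    while True:
        try:
            return step(n)
        except Bounce as b:
            n = b.n


print(run(10_000))  # "done", well past CPython's default limit of 1000
print(depths)       # a single value: every call ran at the same depth
```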

And now, go and have some fun :)

Update (2016-05-10)

Here's a bit of code to compare the naive variant with the "workaround" version:

import sys

# tail_callable and tail_call as defined above


@tail_callable
def fib2(a, b):
    print(a)
    return tail_call(fib2, b, a + b)


def fib1(a, b):
    print(a)
    return fib1(b, a + b)


if __name__ == '__main__':

    modes = {
        'naive': fib1,
        'tail':  fib2
    }
    if len(sys.argv) < 2 or sys.argv[1] not in modes:
        print("Usage: %s {naive|tail}" % sys.argv[0])
        sys.exit(1)

    mode = modes[sys.argv[1]]

    try:
        mode(1, 1)
    except RecursionError:
        # Only the naive version ever gets here; the tail version runs until interrupted
        print("Recursion limit reached")
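For comparison, a common alternative that avoids exceptions entirely is a return-based trampoline: a tail call returns a marker object, and a driver loop keeps calling until it gets back a real value. This is a different technique from the one above, sketched here with hypothetical names (Thunk, trampoline, is_even, is_odd), and it bounces between two mutually recursive functions:

```python
class Thunk:
    """Hypothetical marker meaning 'call fn(*args) next' instead of recursing."""
    __slots__ = ("fn", "args")

    def __init__(self, fn, *args):
        self.fn = fn
        self.args = args


def trampoline(fn, *args):
    # Keep calling until the result is no longer a Thunk.
    result = fn(*args)
    while isinstance(result, Thunk):
        result = result.fn(*result.args)
    return result


def is_even(n):
    return True if n == 0 else Thunk(is_odd, n - 1)


def is_odd(n):
    return False if n == 0 else Thunk(is_even, n - 1)


print(trampoline(is_even, 100_000))  # True
print(trampoline(is_odd, 100_000))   # False
```

The trade-off: there's no exception overhead and no decorator, but a function can no longer return a Thunk as an ordinary value, and callers have to go through trampoline() explicitly.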