Josh Munn's Website

Rewriting a recursive Python function as an abstract machine


The miniKanren paper (Will Byrd’s thesis) describes walk*, which deeply walks a term with respect to a substitution. That means that if a given term is not atomic (e.g. it’s a list), walk* also walks all of its sub-terms, recursively. Here’s a naive translation to Python:

def deep_walk(candidate: Any, sub: Substitution) -> Any:
    """
    Like `walk', but process elements of lists.
    """
    candidate = walk(candidate, sub)
    if isinstance(candidate, list):
        return [deep_walk(x, sub) for x in candidate]
    else:
        return candidate

The problem with this recursive approach is that we’re bound by Python’s recursion limit. Terms may be nested arbitrarily deeply, so the recursive calls to deep_walk may blow the stack. When facing this kind of problem, the general advice from Pythonistas is to “rewrite it iteratively”. Let’s take a look at how we might approach that.
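To see the limit in action, here is a minimal sketch. It assumes a stand-in walk that returns its argument unchanged (the real walk resolves logic variables against the substitution) and uses a plain dict as the substitution; both are simplifications for illustration only.

```python
import sys


def walk(term, sub):
    # Stand-in for the real shallow walk (an assumption for this sketch):
    # with no logic variables involved, a term walks to itself.
    return term


def deep_walk(candidate, sub):
    candidate = walk(candidate, sub)
    if isinstance(candidate, list):
        return [deep_walk(x, sub) for x in candidate]
    return candidate


# Build a term nested a little deeper than the interpreter's recursion limit.
nested = "leaf"
for _ in range(sys.getrecursionlimit() + 100):
    nested = [nested]

try:
    deep_walk(nested, {})
    overflowed = False
except RecursionError:
    # Each level of nesting costs at least one Python stack frame.
    overflowed = True

print(overflowed)
```

Raising the limit with sys.setrecursionlimit only moves the ceiling, which is why the rest of the post looks for a way to avoid deep call stacks altogether.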

One approach is to rewrite the function with an explicit stack that you manually manage, emulating function calls with pushes, and returns with pops. However, the beauty of the recursive approach is that it hides all of that book-keeping away, leaving just a description of the behaviour we really care about. The recursive implementation succinctly describes what it means for a term to be “deeply walked” — it would be a shame to lose that.
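For comparison, here is a rough sketch of what that explicit-stack version might look like. The Var class, the dict-based substitution and this walk are hypothetical stand-ins, not the real microkanren definitions, and deep_walk_stack is a name made up for the example.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Var:
    # Hypothetical stand-in for a logic variable.
    name: str


def walk(term: Any, sub: dict) -> Any:
    # Shallow walk: follow bindings until we hit a non-variable
    # or an unbound variable.
    while isinstance(term, Var) and term in sub:
        term = sub[term]
    return term


def deep_walk_stack(term: Any, sub: dict) -> Any:
    term = walk(term, sub)
    if not isinstance(term, list):
        return term

    root: list = []
    # Each frame pairs an iterator over the pending sub-terms of a list
    # with the output list that collects their walked values.
    stack = [(iter(term), root)]
    while stack:
        pending, acc = stack[-1]
        for item in pending:
            item = walk(item, sub)
            if isinstance(item, list):
                # "Call": push a frame for the nested list, then go
                # around the outer loop to process it first.
                child: list = []
                acc.append(child)
                stack.append((iter(item), child))
                break
            acc.append(item)
        else:
            # "Return": this list is exhausted, so drop its frame.
            stack.pop()
    return root
```

It works, but the pushes, pops and iterator juggling are exactly the book-keeping that obscures what "deeply walked" means.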

Instead, let’s write it as an abstract machine. This won’t absolve us from all book-keeping, and the result will not be nearly as self-descriptive as the recursive approach, but it will keep the focus on the finite states of the program that encapsulate the desired behaviour.

Note (2024-11-26): the context of this post is logic programming and the implementation of miniKanren style languages. However, the content is about programming in general, so I won’t spend any time explaining logic programming specifics.

Defining an abstract machine

Figure: an abstract machine with three registers: Term (λ), Continuation (⟳), and Return (↵), each shown as an empty storage area.

While working on this problem, I was reminded of Jean-Louis Krivine’s paper A call-by-name lambda-calculus machine, which I found while trying to grok justine’s Lambda Calculus in 383 Bytes. It describes an abstract machine for lazy evaluation of lambda calculus terms, which uses only three registers and three operations. We won’t achieve anything near that level of bang-for-buck, but it’s inspiring nevertheless. In particular, the way closures are handled is similar to the way we’ll handle continuations.

We’ll start by building a machine with 3 registers. At any given time in the machine’s life, each register might be holding a value, or it might be empty. We can’t use None or False to represent an empty register, as those are valid terms. We could use a sentinel value, or add another register that describes the machine’s current state, but I like the challenge of keeping the program pure, and minimising the amount of state we capture. Instead, we’ll borrow from functional programming land, and use a Maybe type, defined as follows:

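from typing import Any, Callable, NamedTuple, TypeVar

# Note: `class Just[T]` and the `type` statement below need Python 3.12+.
T = TypeVar("T")
U = TypeVar("U")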


class Just[T]:
    val: T
    __match_args__ = ("val",)

    def __init__(self, val: T):
        self.val = val

    def __repr__(self) -> str:
        return f"Just({self.val!r})"

    def bind(self, f: Callable[[T], "Maybe[U]"]) -> "Maybe[U]":
        return f(self.val)


class Nothing:
    def __eq__(self, other):
        return type(other) is self.__class__

    def __repr__(self) -> str:
        return "Nothing()"

    def bind(self, _):
        return Nothing()


type Maybe[T] = Just[T] | Nothing

This will work nicely with Python’s structural pattern matching.

Next, we’ll define a Python class that represents the state of the machine.

class MachineState(NamedTuple):
    term: Maybe[Any] = Nothing()
    cont: Maybe["MachineState"] = Nothing()
    ret: Maybe[Any] = Nothing()

The three registers are:

  1. term: the term currently being evaluated;
  2. cont: the current continuation (the suspended state we’ll return to after the current computation); and
  3. ret: the term to return to the continuation (or the final return value of the program).

There are 6 states that the machine needs to handle (which we’ll reduce to 5 cases in the code):

  1. we’re looking at an atomic term;
  2. we’re looking at a new list;
  3. we’re continuing processing a list;
  4. we’ve finished processing a list;
  5. we’ve finished processing a sub-term of a list; and
  6. final state — ready to return a term to the caller.

We’ll implement the machine as two mutually recursive functions: shift (which will process states 1 to 4) and unshift (which will process states 5 and 6).

Processing atomic values

The simplest case is accepting an atomic term and returning it to the caller.

+def deep_walk(term, sub):
+
+    def shift(state: MachineState):
+        print("shift: ", state)
+
+        # Do a shallow walk first — if the term is a bound logic variable,
+        # we'll process its associated term.
+        # bind will operate on the wrapped value of Just instances,
+        # and return Nothing() for Nothings.
+        term = state.term.bind(lambda x: Just(walk(x, sub)))
+        state = MachineState(term, state.cont, state.ret)
+
+        match state:
+            case (term, cont, Nothing()):
+                # Shift the atomic term to the return position.
+                return unshift(MachineState(Nothing(), cont, term))
+
+    def unshift(state: MachineState):
+        print("unshift: ", state)
+
+        match state:
+            case (Nothing(), Nothing(), Just(term)):
+                # Final state: return to caller.
+                return term
+
+    return shift(MachineState(Just(term), Nothing(), Nothing()))

Let’s test it out.

>>> deep_walk(3, empty_sub())
shift:  MachineState(term=Just(3), cont=Nothing(), ret=Nothing())
unshift:  MachineState(term=Nothing(), cont=Nothing(), ret=Just(3))
3

Atomic terms are processed and returned to the caller. This is implemented by taking the value in the term register, shifting it to the ret register, and handing the state off to unshift.

>>> x = Var("x")
>>> deep_walk(x, empty_sub().set(x, "foo"))
shift:  MachineState(term=Just('foo'), cont=Nothing(), ret=Nothing())
unshift:  MachineState(term=Nothing(), cont=Nothing(), ret=Just('foo'))
'foo'

Bound variables are walked.

>>> deep_walk([x], empty_sub().set(x, "foo"))
shift:  MachineState(term=Just([Var('x')]), cont=Nothing(), ret=Nothing())
unshift:  MachineState(term=Nothing(), cont=Nothing(), ret=Just([Var('x')]))
[Var('x')]

Sub terms are not processed, yet. We’ll deal with that next.

Processing compound terms

Processing compound terms (lists, for the sake of this example) is slightly more complex. We need to:

  1. suspend the current evaluation while we process a sub-term;
  2. evaluate a sub-term while a continuation is waiting;
  3. resume the suspended evaluation, adding the evaluated sub-term to the output list; and
  4. handle the end of a list, returning that list to the caller or adding it to an accumulator if it’s part of a nested list.

We need to describe each of these states using our 3 registers, so it will be helpful to establish some rules.

  1. Whenever we process a sub-term of a list, we push a continuation that captures the rest of the list in the term register, and an accumulator for sub-terms that have already been processed (stored in the ret register).
  2. Whenever we unshift and the continuation is not Nothing(), we append the ret value of the current state to the accumulator of the continuation state, and hand the continuation state back to shift.
  3. If we’re processing a list in shift, and ret is Nothing(), that means we’re looking at a new list — we need to add a new accumulator to the continuation state.
  4. If we’re processing a list in shift (ret is not Nothing()) and term is Nothing(), we’ve finished processing a list — we should hand ret off to unshift.

We’ll start with handling a new list term.

def deep_walk(term, sub):
    def shift(state: MachineState):
        print("shift: ", state)

        term = state.term.bind(lambda x: Just(walk(x, sub)))
        state = MachineState(term, *state[1:])

        match state:
+            case (Just([a, *d]), cont, Nothing()):
+                # Start processing a list.
+
+                # Create a new continuation state capturing the rest
+                # of the list and a new accumulator.
+                k = MachineState(term=Just(d), cont=cont, ret=Just([]))
+
+                # Process the first sub-term.
+                return shift(MachineState(Just(a), Just(k), Nothing()))
            case (candidate, cont, Nothing()):
                # Process an atomic term.
                return unshift(MachineState(Nothing(), cont, candidate))

    def unshift(state: MachineState):
        print("unshift: ", state)

        match state:
            case (Nothing(), Nothing(), Just(x)):
                # Final state: return a term to the caller.
                return x
+            case (Nothing(), Just(cont), Just(elt)):
+                # We have a continuation, so we're "returning" a sub-term
+                # to its parent list.
+                match cont:
+                    # Restore the continuation, appending sub-term to accumulator.
+                    case (x, k, Just(acc)):
+                        return shift(MachineState(x, k, Just([*acc, elt])))

    return shift(MachineState(Just(term), Nothing(), Nothing()))

Let’s give it a spin.

>>> deep_walk([3], empty_sub())

First, shift receives the initial state:

shift:  MachineState(term=Just([3]), cont=Nothing(), ret=Nothing())

shift captures a continuation where ret contains a new accumulator, and term contains the rest of the list:

shift:  MachineState(term=Just(3), cont=Just(MachineState(term=Just([]), cont=Nothing(), ret=Just([]))), ret=Nothing())

Here’s the continuation state on its own:

MachineState(term=Just([]), cont=Nothing(), ret=Just([]))

unshift receives the state indicating it should return 3 to the accumulator:

unshift:  MachineState(term=Nothing(), cont=Just(...), ret=Just(3))

… and shift receives the continuation with the updated accumulator.

shift:  MachineState(term=Just([]), cont=Nothing(), ret=Just([3]))

However, to abide by the fourth rule, we must indicate the end of the list by setting term to Nothing() in the continuation. Let’s update shift to handle the end of the list. We’ll also update the list case of shift to handle the continued processing of an existing list, where ret already holds the accumulator rather than Nothing().

def deep_walk(term, sub):
    def shift(state: MachineState):
        print("shift: ", state)

        term = state.term.bind(lambda x: Just(walk(x, sub)))
        state = MachineState(term, *state[1:])

        match state:
+            case (Nothing(), cont, Just(ret)):
+                # Got end of list — hand it off to unshift to return.
+                return unshift(MachineState(Nothing(), cont, Just(ret)))

-            case (Just([a, *d]), cont, Nothing()):
-                # Start processing a list.
+            case (Just([a, *d]), cont, ret):
+                # Start or continue processing a list.

                # Create a new continuation state capturing the rest
                # of the list and a new accumulator.
-                k = MachineState(term=Just(d), cont=cont, ret=Just([]))
+                # Set term to Nothing if a is the final sub-term.
+                rest = Just(d) if d else Nothing()
+                # Create a new accumulator if we don't already have one.
+                acc = Just([]) if ret == Nothing() else ret
+                k = MachineState(term=rest, cont=cont, ret=acc)

                # Process the first sub-term.
                return shift(MachineState(Just(a), Just(k), Nothing()))
            case (candidate, cont, Nothing()):
                # Process an atomic term.
                return unshift(MachineState(Nothing(), cont, candidate))

    def unshift(state: MachineState):
        ...

    return shift(MachineState(Just(term), Nothing(), Nothing()))

With those changes, we can process arbitrarily nested lists!

>>> deep_walk([1, 2, 3], empty_sub())
[1, 2, 3]
>>> x = Var("x")
>>> deep_walk([1, x, 3], empty_sub().set(x, "foobar"))
[1, 'foobar', 3]
>>> deep_walk([1, [x, []], 3], empty_sub().set(x, "foobar"))
[1, ['foobar', []], 3]
>>> deep_walk([1, [x, []], [[[[3]]]]], empty_sub().set(x, "foobar"))
[1, ['foobar', []], [[[[3]]]]]
>>> deep_walk([x, [x, []], [[[[3]]]]], empty_sub().set(x, "foobar"))
['foobar', ['foobar', []], [[[[3]]]]]
>>> deep_walk([x, [x, []], [[[[[3]]]], x]], empty_sub().set(x, "foobar"))
['foobar', ['foobar', []], [[[[[3]]]], 'foobar']]

Let’s clean it up and add error handling.

def deep_walk(term, sub):
    def shift(state: MachineState):
        term = state.term.bind(lambda x: Just(walk(x, sub)))
        state = MachineState(term, *state[1:])

        match state:
            case (Nothing(), cont, Just(ret)):
                # End of list — return it.
                return unshift(MachineState(Nothing(), cont, Just(ret)))

            case (Just([a, *d]), cont, ret):
                # Start or continue processing a list.

                # Create a new continuation state capturing the rest
                # of the list and a new accumulator.
                # Set term to Nothing if a is the final sub-term.
                rest = Just(d) if d else Nothing()
                acc = Just([]) if ret == Nothing() else ret
                k = MachineState(term=rest, cont=cont, ret=acc)

                # Process the first sub-term.
                return shift(MachineState(Just(a), Just(k), Nothing()))
            case (candidate, cont, Nothing()):
                # Process an atomic term.
                return unshift(MachineState(Nothing(), cont, candidate))
            case _:
                raise ValueError("shift got invalid state")

    def unshift(state: MachineState):
        match state:
            case (Nothing(), Nothing(), Just(x)):
                # Final state: return a term to the caller.
                return x
            case (Nothing(), Just(MachineState(x, k, Just(acc))), Just(elt)):
                # We have a continuation, so we're "returning" a sub-term
                # to its parent list.
                # Restore the continuation, appending sub-term to accumulator.
                return shift(MachineState(x, k, Just([*acc, elt])))
            case _:
                raise ValueError("unshift got invalid state")

    return shift(MachineState(Just(term), Nothing(), Nothing()))

Handling deep recursion

In the opening of this post I promised a way to work around Python’s recursion limit. So far we’ve not achieved that, as we make direct recursive calls to both shift and unshift. To make good on that promise, we’ll “trampoline” the implementation. This is straightforward: any time we would return a recursive call to shift or unshift, we wrap that call in a lambda and return it (a “thunk”), and wrap the initial call in a helper that evaluates each thunk that is returned. With this in place, the stack growth is bounded to 2 frames (one for the thunk, one for the call to shift or unshift), and we can handle arbitrarily deep computations.

-def deep_walk(term, sub):
+def deep_walk_tramp(term, sub):
+    def pull(x):
+        while callable(x):
+            x = x()
+        return x

    def shift(state: MachineState):
        term = state.term.bind(lambda x: Just(walk(x, sub)))
        state = MachineState(term, *state[1:])

        match state:
            case (Nothing(), cont, Just(ret)):
                # End of list — return it.
-                return unshift(MachineState(Nothing(), cont, Just(ret)))
+                return lambda: unshift(MachineState(Nothing(), cont, Just(ret)))

            case (Just([a, *d]), cont, ret):
                # Start or continue processing a list.

                # Create a new continuation state capturing the rest
                # of the list and a new accumulator.
                # Set term to Nothing if a is the final sub-term.
                rest = Just(d) if d else Nothing()
                acc = Just([]) if ret == Nothing() else ret
                k = MachineState(term=rest, cont=cont, ret=acc)

                # Process the first sub-term.
-                return shift(MachineState(Just(a), Just(k), Nothing()))
+                return lambda: shift(MachineState(Just(a), Just(k), Nothing()))
            case (candidate, cont, Nothing()):
                # Process an atomic term.
-                return unshift(MachineState(Nothing(), cont, candidate))
+                return lambda: unshift(MachineState(Nothing(), cont, candidate))
            case _:
                raise ValueError("shift got invalid state")

    def unshift(state: MachineState):
        match state:
            case (Nothing(), Nothing(), Just(x)):
                # Final state: return a term to the caller.
                return x
            case (Nothing(), Just(MachineState(x, k, Just(acc))), Just(elt)):
                # We have a continuation, so we're "returning" a sub-term
                # to its parent list.
                # Restore the continuation, appending term to accumulator.
-                return shift(MachineState(x, k, Just([*acc, elt])))
+                return lambda: shift(MachineState(x, k, Just([*acc, elt])))
            case _:
                raise ValueError("unshift got invalid state")

-    return shift(MachineState(Just(term), Nothing(), Nothing()))
+    return pull(shift(MachineState(Just(term), Nothing(), Nothing())))
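One thing to be aware of: pull treats any callable value as a thunk, so a term that is itself callable would be invoked rather than returned. The sketch below shows the same trampolining pattern in isolation, on a toy pair of mutually recursive functions, using a hypothetical Thunk wrapper class (not part of the code above) to mark deferred calls explicitly.

```python
import sys


class Thunk:
    """Marks a deferred call, so callable results are never mistaken for one."""

    __slots__ = ("fn",)

    def __init__(self, fn):
        self.fn = fn


def pull(x):
    # Keep evaluating deferred calls until a plain value comes back.
    while isinstance(x, Thunk):
        x = x.fn()
    return x


def is_even(n):
    if n == 0:
        return True
    # Defer the mutually recursive call instead of making it directly.
    return Thunk(lambda: is_odd(n - 1))


def is_odd(n):
    if n == 0:
        return False
    return Thunk(lambda: is_even(n - 1))


# Far more bounces than the recursion limit would allow as nested calls.
depth = sys.getrecursionlimit() * 10
print(pull(is_even(depth)))

# A callable result is returned untouched rather than being called.
print(pull(Thunk(lambda: len)) is len)
```

Whether the extra wrapper is worth it depends on whether callable terms can ever appear; for terms made of atoms, variables and lists, the callable check is enough.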

Rewriting it iteratively

The number of cases we have to handle is small, and it turns out they are mutually exclusive. We can rewrite deep_walk iteratively without much trouble. In the process we can remove the “end of list” case from shift — it’s redundant.

def deep_walk_iter(term, sub):
    state = MachineState(Just(term), Nothing(), Nothing())

    while True:
        term = state.term.bind(lambda x: Just(walk(x, sub)))
        state = MachineState(term, state.cont, state.ret)

        match state:
            case (Nothing(), Nothing(), Just(x)):
                # Final state: return a term to the caller.
                return x
            case (Nothing(), Just(MachineState(x, k, Just(acc))), Just(elt)):
                # We have a continuation, so we're "returning" a
                # sub-term to its parent list.

                # Restore the continuation, appending term to
                # accumulator.
                state = MachineState(x, k, Just([*acc, elt]))
                continue
            case (Just([a, *d]), cont, ret):
                # Start or continue processing a list.

                # Create a new continuation state capturing the rest
                # of the list and a new accumulator.
                # Set term to Nothing if a is the final sub-term.
                term = Just(d) if d else Nothing()
                acc = Just([]) if ret == Nothing() else ret
                k = MachineState(term=term, cont=cont, ret=acc)

                # Process the first sub-term.
                state = MachineState(Just(a), Just(k), Nothing())
                continue
            case (candidate, cont, Nothing()):
                # Process an atomic term.
                state = MachineState(Nothing(), cont, candidate)
                continue
            case _:
                raise ValueError("deep_walk_iter got invalid state")

Instead of making mutually recursive function calls, we’re using a while loop, updating the state variable on each iteration until the process is complete.

Performance comparison of recursive, iterative and trampolined approaches

I anticipate that the iterative version of deep walk will be faster due to the reduced function-call overhead. To test this out, we’ll time the execution of a kanren goal that generates a lot of deeply nested lists as it searches for more and more solutions: lambda-term°.

(defn var° [x]
  (== x 'x))

(defne abstraction° [term]
  ([['λ v ': t^]] (var° v)))

(defne application° [term]
  ([[rator rand]] (lambda-term° rator) (lambda-term° rand)))

(defn lambda-term° [term]
  (conde
   [(abstraction° term)]
   [(var° term)]
   [(application° term)]))

Its results look like this:

=> (run 5 [x] (lambda-term° x))
['x
 ['λ 'x : _.0]
 ['x 'x]
 ['x ['λ 'x : _.0]]
 [['λ 'x : _.0] 'x]
]

deep_walk is used in the reification process. We’ll patch it with the different implementations to compare execution times. I ran these simple benchmarks on Python 3.13.0.

First, the trampolined version.

=> (setv (. (get sys.modules "microkanren.core") deep_walk)
         (. (get sys.modules "microkanren.core") deep_walk_tramp))
=> (timeit (fn [] (run 900 [x] (lambda-term° x))) :globals (globals) :number 100)
40.51025476504583

Then the iterative version.

=> (setv (. (get sys.modules "microkanren.core") deep_walk)
         (. (get sys.modules "microkanren.core") deep_walk_iter))
=> (timeit (fn [] (run 900 [x] (lambda-term° x))) :globals (globals) :number 100)
33.74411466499441

The iterative version is significantly faster than the trampolined implementation, but it’s slower than the original recursive implementation!

=> (setv (. (get sys.modules "microkanren.core") deep_walk)
... (. (get sys.modules "microkanren.core") deep_walk_recur))
=> (timeit (fn [] (run 900 [x] (lambda-term° x))) :globals (globals) :number 100)
18.099126554006943

I was able to improve the speed of the iterative version by:

  1. replacing the MachineState namedtuples with plain tuples;
  2. using single-element tuples and the empty tuple in place of Just and Nothing, respectively; and
  3. allocating an accumulator list of the correct length the first time a list term is encountered, and adding an index register to the state so that the accumulator elements can be updated in-place, rather than allocating lots of intermediary lists.

=> (setv (. (get sys.modules "microkanren.core") deep_walk)
... (. (get sys.modules "microkanren.core") deep_walk_iter))
=> (timeit (fn [] (run 900 [x] (lambda-term° x))) :globals (globals) :number 100)
20.935525032982696

That’s still a little underwhelming, so I reimplemented it using if statements rather than pattern matching and dropped the Maybe concept altogether, using a sentinel object for empty values. I found some discussions online that suggested that the code generated for match statements can have a lot of redundancy, and the disassembled code backs that up (the specifics are a topic for another post). Here’s the final iterative implementation, and timings. It runs faster than the original recursive implementation.

def deep_walk_iter_if(term, sub):
    # term, cont, acc, index
    NOTHING = object()
    state = (term, NOTHING, NOTHING, NOTHING)

    while True:
        term, cont, ret, i = state
        term = walk(term, sub)

        if term is NOTHING:
            if cont is NOTHING and i is NOTHING:
                # Final state: return a term to the caller.
                return ret
            else:
                # We have a continuation, so we're "returning" a
                # sub-term to its parent list.
                x, k, acc, i = cont

                # Restore the continuation, inserting ret into acc at
                # index i.
                acc[i] = ret

                # If the cont term is NOTHING, list processing is
                # done, so drop the index. Otherwise, bump the index
                # so the next time we enter this case we update the
                # next element in the list with its walked sub-term.
                index = i + 1 if x is not NOTHING else NOTHING

                state = (x, k, acc, index)
                continue
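        elif isinstance(term, list) and term:
            # Start or continue processing a non-empty list. An empty
            # list has no sub-terms, so it falls through to the atomic
            # case (unpacking it below would raise ValueError).
            a, *d = term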

            # Create a new continuation state capturing the rest
            # of the list and a new accumulator.
            if i is NOTHING:
                # Allocate a new accumulator.
                length = len(d) + 1
                acc = [NOTHING] * length
                i = 0
            else:
                # Use the existing accumulator and index (the index is
                # bumped when a sub-term is returned to its parent).
                acc = ret

            # Set term to NOTHING if a is the final sub-term.
            term = d or NOTHING
            k = (term, cont, acc, i)
            # Process the first sub-term.
            state = (a, k, NOTHING, NOTHING)
            continue
        else:
            # Process an atomic term.
            state = (NOTHING, cont, term, NOTHING)
            continue

=> (setv (. (get sys.modules "microkanren.core") deep_walk)
... (. (get sys.modules "microkanren.core") deep_walk_iter))
=> (timeit (fn [] (run 900 [x] (lambda-term° x))) :globals (globals) :number 100)
15.508090497984085

Here’s a comparison of the timings of the different implementations:

Implementation                Benchmark (seconds)
deep_walk_tramp               40.51025476504583
deep_walk_iter (optimised)    20.935525032982696
deep_walk_recur               18.099126554006943
deep_walk_iter_if             15.508090497984085

The benchmarks are not comprehensive, so take them with a grain of salt. Thorough benchmarking would consider each implementation along controlled axes of list length and degree of nested-ness.
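A more controlled benchmark might start from something like the sketch below. make_term and bench are hypothetical helpers, and deep_copy_recur is a stand-in for whichever deep_walk implementation is under test (it ignores substitutions entirely).

```python
import timeit


def make_term(length, depth):
    # Hypothetical generator: `depth` extra levels of nesting, with
    # `length` items at every level.
    term = list(range(length))
    for _ in range(depth):
        term = [term, *range(length - 1)]
    return term


def deep_copy_recur(term):
    # Stand-in for a deep_walk implementation.
    if isinstance(term, list):
        return [deep_copy_recur(x) for x in term]
    return term


def bench(impl, lengths, depths, number=100):
    # Time impl on one generated term per (length, depth) combination.
    timings = {}
    for length in lengths:
        for depth in depths:
            term = make_term(length, depth)
            timings[(length, depth)] = timeit.timeit(
                lambda: impl(term), number=number
            )
    return timings


for (length, depth), seconds in bench(
    deep_copy_recur, [1, 10, 100], [1, 10, 100]
).items():
    print(f"length={length:<4} depth={depth:<4} {seconds:.4f}s")
```

Running the same grid against each implementation would show whether the ranking above holds for wide, shallow terms as well as narrow, deep ones.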


In summary, we’ve explored how to transform a recursive function into an abstract machine, and the result is readable (if a little austere). What remains is to add support for more container types, especially tuples and cons pairs, so that I can use the implementation in my microkanren library. Some of the library’s other functions will need similar treatment. One benefit of this iterative, imperative implementation is that I think it will be a fairly direct translation to Rust, which is an experiment I’m looking forward to.
