Rewriting a recursive Python function as an abstract machine

Will Byrd’s miniKanren thesis describes walk*, which deeply walks a term with respect to a substitution: if the term is not atomic (e.g. it’s a list), walk* also walks all of its sub-terms, recursively. Here’s a naive translation to Python:
```python
from typing import Any

# walk and Substitution come from the surrounding miniKanren implementation.


def deep_walk(candidate: Any, sub: Substitution) -> Any:
    """
    Like `walk', but process elements of lists.
    """
    candidate = walk(candidate, sub)
    if isinstance(candidate, list):
        return [deep_walk(x, sub) for x in candidate]
    else:
        return candidate
```
The problem with this recursive approach is that we’re bound by Python’s recursion limit. Terms may be nested arbitrarily deeply, so the recursive calls to deep_walk may exceed the limit (1,000 frames by default; see sys.getrecursionlimit()) and raise a RecursionError. When facing this kind of problem, the general advice from Pythonistas is to “rewrite it iteratively”. Let’s take a look at how we might approach that.
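To make the failure concrete, here is a small self-contained sketch. It uses a hypothetical stand-in for `walk` (strings beginning with `?` play the role of logic variables, and a plain dict plays the substitution), since the real `walk` and `Substitution` come from the author’s library. The recursive shape is the same as `deep_walk` above; only the nesting depth changes between the two calls.

```python
import sys


def walk(term, sub):
    # Hypothetical shallow walk: strings starting with "?" stand in for
    # logic variables, and the substitution is a plain dict.
    while isinstance(term, str) and term.startswith("?") and term in sub:
        term = sub[term]
    return term


def deep_walk_recur(term, sub):
    # Same shape as the naive deep_walk: one Python frame per nesting level.
    term = walk(term, sub)
    if isinstance(term, list):
        return [deep_walk_recur(x, sub) for x in term]
    return term


def nested(depth):
    # Build [[...[0]...]] with a loop, so construction itself never recurses.
    term = [0]
    for _ in range(depth):
        term = [term]
    return term


# Shallow terms are fine: prints ['foo', [1, 'foo']]
print(deep_walk_recur(["?x", [1, "?y"]], {"?x": "foo", "?y": "?x"}))

try:
    deep_walk_recur(nested(2 * sys.getrecursionlimit()), {})
except RecursionError:
    print("RecursionError")  # nesting deeper than the limit fails
```

Building the nested input with a loop matters here: constructing it recursively would hit the same limit before `deep_walk_recur` ever ran.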
One approach is to rewrite the function with an explicit stack that you manually manage, emulating function calls with pushes, and returns with pops. However, the beauty of the recursive approach is that it hides all of that book-keeping away, leaving just a description of the behaviour we really care about. The recursive implementation succinctly describes what it means for a term to be “deeply walked” — it would be a shame to lose that.
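For comparison, here is a minimal sketch of that explicit-stack rewrite, again assuming the hypothetical `walk` stand-in (`?`-prefixed strings as variables, a dict as the substitution). Each stack entry pairs an iterator over the sub-terms still to visit with the output list they are appended to, so pushing an entry plays the role of a call and popping it plays the role of a return.

```python
def walk(term, sub):
    # Hypothetical shallow walk: strings starting with "?" are variables.
    while isinstance(term, str) and term.startswith("?") and term in sub:
        term = sub[term]
    return term


def deep_walk_stack(term, sub):
    term = walk(term, sub)
    if not isinstance(term, list):
        return term
    root = []
    # Each frame: (iterator over remaining sub-terms, output list to fill).
    stack = [(iter(term), root)]
    while stack:
        items, out = stack[-1]
        for item in items:
            item = walk(item, sub)
            if isinstance(item, list):
                child = []
                out.append(child)  # filled in later, by reference
                stack.append((iter(item), child))  # "call"
                break  # descend before finishing the current list
            out.append(item)
        else:
            stack.pop()  # list exhausted: "return" to the parent frame
    return root
```

It handles any depth, but the definition of a “deeply walked” term is now spread across iterator bookkeeping, a `break`, and a `for`/`else`, which is exactly the loss of clarity described above.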
Instead, let’s write it as an abstract machine. This won’t absolve us from all book-keeping, and the result will not be nearly as self-descriptive as the recursive approach, but it will keep the focus on the finite states of the program that encapsulate the desired behaviour.
Note (2024-11-26): the context of this post is logic programming and the implementation of miniKanren style languages. However, the content is about programming in general, so I won’t spend any time explaining logic programming specifics.
Defining an abstract machine
While working on this problem, I was reminded of Jean-Louis Krivine’s paper A call-by-name lambda-calculus machine, which I found while trying to grok justine’s Lambda Calculus in 383 Bytes. It describes an abstract machine for lazy evaluation of lambda calculus terms, which uses only three registers and three operations. We won’t achieve anything near that level of bang-for-buck, but it’s inspiring nevertheless. In particular, the way closures are handled is similar to the way we’ll handle continuations.
We’ll start by building a machine with 3 registers. At any given time in the machine’s life, each register might be holding a value, or it might be empty. We can’t use None or False to represent an empty register, as those are valid terms. We could use a sentinel value, or add another register that describes the machine’s current state, but I like the challenge of keeping the program pure, and minimising the amount of state we capture. Instead, we’ll borrow from functional programming land, and use a Maybe type, defined as follows:
```python
# Requires Python 3.12+ for the class Just[T] and `type` alias syntax.
from typing import Callable, TypeVar

T = TypeVar("T")
U = TypeVar("U")


class Just[T]:
    val: T
    __match_args__ = ("val",)

    def __init__(self, val: T):
        self.val = val

    def __repr__(self) -> str:
        return f"Just({self.val!r})"

    def bind(self, f: Callable[[T], "Maybe[U]"]) -> "Maybe[U]":
        return f(self.val)


class Nothing:
    def __eq__(self, other):
        return type(other) is self.__class__

    def __repr__(self) -> str:
        return "Nothing()"

    def bind(self, _):
        return Nothing()


type Maybe[T] = Just[T] | Nothing
```
This will work nicely with Python’s structural pattern matching.
Next, we’ll define a Python class that represents the state of the machine.
```python
from typing import Any, NamedTuple


class MachineState(NamedTuple):
    term: Maybe[Any] = Nothing()
    cont: Maybe["MachineState"] = Nothing()
    ret: Maybe[Any] = Nothing()
```
The three registers are:
- `term`: the term currently being evaluated;
- `cont`: the current continuation (the suspended state we’ll return to after the current computation); and
- `ret`: the term to return to the continuation (or the final return value of the program).
There are six states that the machine needs to handle (which we’ll reduce to five cases in the code):
1. we’re looking at an atomic term;
2. we’re looking at a new list;
3. we’re continuing to process a list;
4. we’ve finished processing a list;
5. we’ve finished processing a sub-term of a list; and
6. final state: ready to return a term to the caller.
We’ll implement the machine as two mutually recursive functions: shift (which will process states 1 to 4) and unshift (which will process states 5 and 6).
Processing atomic values
The simplest case is accepting an atomic term and returning it to the caller.
```python
def deep_walk(term, sub):

    def shift(state: MachineState):
        print("shift: ", state)

        # Do a shallow walk first: if the term is a bound logic variable,
        # we'll process its associated term.
        # bind will operate on the wrapped value of Just instances,
        # and return Nothing() for Nothings.
        term = state.term.bind(lambda x: Just(walk(x, sub)))
        state = MachineState(term, state.cont, state.ret)

        match state:
            case (term, cont, Nothing()):
                # Shift the atomic term to the return position.
                return unshift(MachineState(Nothing(), cont, term))

    def unshift(state: MachineState):
        print("unshift: ", state)

        match state:
            case (Nothing(), Nothing(), Just(term)):
                # Final state: return to caller.
                return term

    return shift(MachineState(Just(term), Nothing(), Nothing()))
```
Let’s test it out.
```
>>> deep_walk(3, empty_sub())
shift: MachineState(term=Just(3), cont=Nothing(), ret=Nothing())
unshift: MachineState(term=Nothing(), cont=Nothing(), ret=Just(3))
3
```
Atomic terms are processed and returned to the caller. This is implemented by taking the value in the term register, shifting it to the ret register, and handing the state off to unshift.
```
>>> x = Var("x")
>>> deep_walk(x, empty_sub().set(x, "foo"))
shift: MachineState(term=Just(Var('x')), cont=Nothing(), ret=Nothing())
unshift: MachineState(term=Nothing(), cont=Nothing(), ret=Just('foo'))
'foo'
```
Bound variables are walked.
```
>>> deep_walk([x], empty_sub().set(x, "foo"))
shift: MachineState(term=Just([Var('x')]), cont=Nothing(), ret=Nothing())
unshift: MachineState(term=Nothing(), cont=Nothing(), ret=Just([Var('x')]))
[Var('x')]
```
Sub-terms are not processed yet. We’ll deal with that next.
Processing compound terms
Processing compound terms (lists, for the sake of this example) is slightly more complex. We need to:
- suspend the current evaluation while we process a sub-term;
- evaluate a sub-term while a continuation is waiting;
- resume the suspended evaluation, adding the evaluated sub-term to the output list; and
- handle the end of a list, returning that list to the caller or adding it to an accumulator if it’s part of a nested list.
We need to describe each of these states using our 3 registers, so it will be helpful to establish some rules.
- Whenever we process a sub-term of a list, we push a continuation that captures the rest of the list in the `term` register, and an accumulator for sub-terms that have already been processed (stored in the `ret` register).
- Whenever we `unshift` and the continuation is not `Nothing()`, we append the `ret` value of the current state to the accumulator of the continuation state, and hand the continuation state back to `shift`.
- If we’re processing a list in `shift` and `ret` is `Nothing()`, we’re looking at a new list, so we need to add a new accumulator to the continuation state.
- If we’re processing a list in `shift` (`ret` is not `Nothing()`) and `term` is `Nothing()`, we’ve finished processing a list, so we should hand `ret` off to `unshift`.
We’ll start with handling a new list term.
```diff
 def deep_walk(term, sub):
     def shift(state: MachineState):
         print("shift: ", state)
         term = state.term.bind(lambda x: Just(walk(x, sub)))
         state = MachineState(term, *state[1:])
         match state:
+            case (Just([a, *d]), cont, Nothing()):
+                # Start processing a list.
+
+                # Create a new continuation state capturing the rest
+                # of the list and a new accumulator.
+                k = MachineState(term=Just(d), cont=cont, ret=Just([]))
+
+                # Process the first sub-term.
+                return shift(MachineState(Just(a), Just(k), Nothing()))
             case (candidate, cont, Nothing()):
                 # Process an atomic term.
                 return unshift(MachineState(Nothing(), cont, candidate))

     def unshift(state: MachineState):
         print("unshift: ", state)
         match state:
             case (Nothing(), Nothing(), Just(x)):
                 # Final state: return a term to the caller.
                 return x
+            case (Nothing(), Just(cont), Just(elt)):
+                # We have a continuation, so we're "returning" a sub-term
+                # to its parent list.
+                match cont:
+                    # Restore the continuation, appending sub-term to accumulator.
+                    case (x, k, Just(acc)):
+                        return shift(MachineState(x, k, Just([*acc, elt])))

     return shift(MachineState(Just(term), Nothing(), Nothing()))
```
Let’s give it a spin.
```
>>> deep_walk([3], empty_sub())
```
First, shift receives the initial state:
```
shift: MachineState(term=Just([3]), cont=Nothing(), ret=Nothing())
```
shift captures a continuation where ret contains a new accumulator, and term contains the rest of the list:
```
shift: MachineState(term=Just(3), cont=Just(MachineState(term=Just([]), cont=Nothing(), ret=Just([]))), ret=Nothing())
```
Here’s the continuation state on its own:
```
MachineState(term=Just([]), cont=Nothing(), ret=Just([]))
```
unshift receives the state indicating it should return 3 to the accumulator:
```
unshift: MachineState(term=Nothing(), cont=Just(...), ret=Just(3))
```
… and shift receives the continuation with the updated accumulator.
```
shift: MachineState(term=Just([]), cont=Nothing(), ret=Just([3]))
```
However, the continuation’s term is still Just([]) rather than Nothing(), so this version of shift matches none of its cases and the call returns None. To abide by the fourth rule, we must indicate the end of the list by setting term to Nothing() in the continuation. Let’s update shift to handle the end of the list. We’ll also update the list case of shift to handle the continued processing of an existing list, where ret is not Nothing() but the accumulator.
```diff
 def deep_walk(term, sub):
     def shift(state: MachineState):
         print("shift: ", state)
         term = state.term.bind(lambda x: Just(walk(x, sub)))
         state = MachineState(term, *state[1:])
         match state:
+            case (Nothing(), cont, Just(ret)):
+                # Got end of list: hand it off to unshift to return.
+                return unshift(MachineState(Nothing(), cont, Just(ret)))
-            case (Just([a, *d]), cont, Nothing()):
-                # Start processing a list.
+            case (Just([a, *d]), cont, ret):
+                # Start or continue processing a list.
                 # Create a new continuation state capturing the rest
                 # of the list and a new accumulator.
-                k = MachineState(term=Just(d), cont=cont, ret=Just([]))
+                # Set term to Nothing if a is the final sub-term.
+                rest = Just(d) if d else Nothing()
+                # Create a new accumulator if we don't already have one.
+                acc = Just([]) if ret == Nothing() else ret
+                k = MachineState(term=rest, cont=cont, ret=acc)
                 # Process the first sub-term.
                 return shift(MachineState(Just(a), Just(k), Nothing()))
             case (candidate, cont, Nothing()):
                 # Process an atomic term.
                 return unshift(MachineState(Nothing(), cont, candidate))

     def unshift(state: MachineState):
         ...

     return shift(MachineState(Just(term), Nothing(), Nothing()))
```
With those changes, we can process nested lists of any shape (their depth is still bounded by the recursion limit, which we’ll deal with shortly):
```
>>> deep_walk([1, 2, 3], empty_sub())
[1, 2, 3]
>>> x = Var("x")
>>> deep_walk([1, x, 3], empty_sub().set(x, "foobar"))
[1, 'foobar', 3]
>>> deep_walk([1, [x, []], 3], empty_sub().set(x, "foobar"))
[1, ['foobar', []], 3]
>>> deep_walk([1, [x, []], [[[[3]]]]], empty_sub().set(x, "foobar"))
[1, ['foobar', []], [[[[3]]]]]
>>> deep_walk([x, [x, []], [[[[3]]]]], empty_sub().set(x, "foobar"))
['foobar', ['foobar', []], [[[[3]]]]]
>>> deep_walk([x, [x, []], [[[[[3]]]], x]], empty_sub().set(x, "foobar"))
['foobar', ['foobar', []], [[[[[3]]]], 'foobar']]
```
Let’s clean it up and add error handling.
```python
def deep_walk(term, sub):
    def shift(state: MachineState):
        term = state.term.bind(lambda x: Just(walk(x, sub)))
        state = MachineState(term, *state[1:])
        match state:
            case (Nothing(), cont, Just(ret)):
                # End of list: return it.
                return unshift(MachineState(Nothing(), cont, Just(ret)))
            case (Just([a, *d]), cont, ret):
                # Start or continue processing a list.
                # Create a new continuation state capturing the rest
                # of the list and a new accumulator.
                # Set term to Nothing if a is the final sub-term.
                rest = Just(d) if d else Nothing()
                acc = Just([]) if ret == Nothing() else ret
                k = MachineState(term=rest, cont=cont, ret=acc)
                # Process the first sub-term.
                return shift(MachineState(Just(a), Just(k), Nothing()))
            case (candidate, cont, Nothing()):
                # Process an atomic term.
                return unshift(MachineState(Nothing(), cont, candidate))
            case _:
                raise ValueError("shift got invalid state")

    def unshift(state: MachineState):
        match state:
            case (Nothing(), Nothing(), Just(x)):
                # Final state: return a term to the caller.
                return x
            case (Nothing(), Just(MachineState(x, k, Just(acc))), Just(elt)):
                # We have a continuation, so we're "returning" a sub-term
                # to its parent list.
                # Restore the continuation, appending sub-term to accumulator.
                return shift(MachineState(x, k, Just([*acc, elt])))
            case _:
                raise ValueError("unshift got invalid state")

    return shift(MachineState(Just(term), Nothing(), Nothing()))
```
Handling deep recursion
In the opening of this post I promised a way to work around Python’s recursion limit. So far we haven’t achieved that, as we make direct recursive calls to both shift and unshift. To make good on that promise, we’ll “trampoline” the implementation. This is straightforward: any time we would return a recursive call to shift or unshift, we instead wrap that call in a lambda and return it (a “thunk”), and we wrap the initial call in a helper, pull, that keeps evaluating the thunks it gets back until it receives a final value. With this in place, the stack stays a constant two frames deep above pull (one for the thunk, one for the call to shift or unshift that it makes), so we can handle arbitrarily deep computations.
```diff
-def deep_walk(term, sub):
+def deep_walk_tramp(term, sub):
+    def pull(x):
+        while callable(x):
+            x = x()
+        return x
+
     def shift(state: MachineState):
         term = state.term.bind(lambda x: Just(walk(x, sub)))
         state = MachineState(term, *state[1:])
         match state:
             case (Nothing(), cont, Just(ret)):
                 # End of list: return it.
-                return unshift(MachineState(Nothing(), cont, Just(ret)))
+                return lambda: unshift(MachineState(Nothing(), cont, Just(ret)))
             case (Just([a, *d]), cont, ret):
                 # Start or continue processing a list.
                 # Create a new continuation state capturing the rest
                 # of the list and a new accumulator.
                 # Set term to Nothing if a is the final sub-term.
                 rest = Just(d) if d else Nothing()
                 acc = Just([]) if ret == Nothing() else ret
                 k = MachineState(term=rest, cont=cont, ret=acc)
                 # Process the first sub-term.
-                return shift(MachineState(Just(a), Just(k), Nothing()))
+                return lambda: shift(MachineState(Just(a), Just(k), Nothing()))
             case (candidate, cont, Nothing()):
                 # Process an atomic term.
-                return unshift(MachineState(Nothing(), cont, candidate))
+                return lambda: unshift(MachineState(Nothing(), cont, candidate))
             case _:
                 raise ValueError("shift got invalid state")

     def unshift(state: MachineState):
         match state:
             case (Nothing(), Nothing(), Just(x)):
                 # Final state: return a term to the caller.
                 return x
             case (Nothing(), Just(MachineState(x, k, Just(acc))), Just(elt)):
                 # We have a continuation, so we're "returning" a sub-term
                 # to its parent list.
                 # Restore the continuation, appending term to accumulator.
-                return shift(MachineState(x, k, Just([*acc, elt])))
+                return lambda: shift(MachineState(x, k, Just([*acc, elt])))
             case _:
                 raise ValueError("unshift got invalid state")

-    return shift(MachineState(Just(term), Nothing(), Nothing()))
+    return pull(shift(MachineState(Just(term), Nothing(), Nothing())))
```
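Because pull treats any callable return value as a thunk, it quietly assumes that no walked term is ever callable; a term that happened to be a function would be called instead of returned. Below is a minimal sketch of the same trampolining technique that avoids that assumption by wrapping suspended calls in a dedicated class. The `Thunk` class and this `pull` are hypothetical, and they are shown on a classic pair of mutually recursive functions rather than on `deep_walk`.

```python
class Thunk:
    # A suspended call. Tagging it with its own class lets the trampoline
    # tell "more work to do" apart from a final result that is callable.
    __slots__ = ("fn", "args")

    def __init__(self, fn, *args):
        self.fn = fn
        self.args = args


def pull(result):
    # Force thunks one at a time; each call returns before the next starts,
    # so the stack stays shallow however many steps there are.
    while isinstance(result, Thunk):
        result = result.fn(*result.args)
    return result


def is_even(n):
    return True if n == 0 else Thunk(is_odd, n - 1)


def is_odd(n):
    return False if n == 0 else Thunk(is_even, n - 1)


print(pull(is_even(100_000)))           # True, far past the recursion limit
print(pull(Thunk(lambda: len)) is len)  # True: a callable result is returned, not called
```

Applying the same idea to `deep_walk_tramp` would mean returning `Thunk(shift, state)` or `Thunk(unshift, state)` where it currently returns a lambda.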
Rewriting it iteratively
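The number of cases we have to handle is small, and it turns out they are mutually exclusive, so we can rewrite deep_walk iteratively without much trouble. In the process we can remove the “end of list” case from shift. It only forwarded the state (Nothing(), cont, Just(ret)) to unshift unchanged, and in a single loop that state is matched directly by the final-state case (when cont is Nothing()) or the continuation case (otherwise), so it’s redundant.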
```python
def deep_walk_iter(term, sub):
    state = MachineState(Just(term), Nothing(), Nothing())
    while True:
        term = state.term.bind(lambda x: Just(walk(x, sub)))
        state = MachineState(term, state.cont, state.ret)
        match state:
            case (Nothing(), Nothing(), Just(x)):
                # Final state: return a term to the caller.
                return x
            case (Nothing(), Just(MachineState(x, k, Just(acc))), Just(elt)):
                # We have a continuation, so we're "returning" a
                # sub-term to its parent list.
                # Restore the continuation, appending term to
                # accumulator.
                state = MachineState(x, k, Just([*acc, elt]))
                continue
            case (Just([a, *d]), cont, ret):
                # Start or continue processing a list.
                # Create a new continuation state capturing the rest
                # of the list and a new accumulator.
                # Set term to Nothing if a is the final sub-term.
                term = Just(d) if d else Nothing()
                acc = Just([]) if ret == Nothing() else ret
                k = MachineState(term=term, cont=cont, ret=acc)
                # Process the first sub-term.
                state = MachineState(Just(a), Just(k), Nothing())
                continue
            case (candidate, cont, Nothing()):
                # Process an atomic term.
                state = MachineState(Nothing(), cont, candidate)
                continue
            case _:
                raise ValueError("deep_walk_iter got invalid state")
```
Instead of making mutually recursive function calls, we use a while loop, updating the state variable on each iteration until the process is complete.
Performance comparison of recursive, iterative and trampolined approaches
I anticipate that the iterative version of deep walk will be faster due to the reduced function-call overhead. To test this out, we’ll time the execution of a kanren goal that generates a lot of deeply nested lists as it searches for more and more solutions: lambda-term°.
```hy
(defn var° [x]
  (== x 'x))

(defne abstraction° [term]
  ([['λ v ': t^]] (var° v)))

(defne application° [term]
  ([[rator rand]] (lambda-term° rator) (lambda-term° rand)))

(defn lambda-term° [term]
  (conde
    [(abstraction° term)]
    [(var° term)]
    [(application° term)]))
```
Its results look like this:
```hy
=> (run 5 [x] (lambda-term° x))
['x
 ['λ 'x : _.0]
 ['x 'x]
 ['x ['λ 'x : _.0]]
 [['λ 'x : _.0] 'x]
 ]
```
deep_walk is used in the reification process. We’ll patch it with the different implementations to compare execution times. I ran these simple benchmarks on Python 3.13.0.
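In Python terms, each setv form below is an attribute assignment on the microkanren.core module object, roughly `sys.modules["microkanren.core"].deep_walk = ...`. That only changes behaviour for code that looks `deep_walk` up as a module global at call time, which this setup assumes the reification code does; any name bound earlier with `from ... import deep_walk` keeps pointing at the old function. A self-contained demonstration with a throwaway module (`fake_core` and its functions are hypothetical):

```python
import sys
import types

# A stand-in module whose reify() looks deep_walk up in the module's
# globals every time it runs.
core = types.ModuleType("fake_core")
exec(
    "def deep_walk(term, sub):\n"
    "    return ('original', term)\n"
    "\n"
    "def reify(term):\n"
    "    return deep_walk(term, {})\n",
    core.__dict__,
)
sys.modules["fake_core"] = core

from fake_core import deep_walk as bound_early  # copied before patching

# Equivalent of (setv (. (get sys.modules "fake_core") deep_walk) ...)
setattr(sys.modules["fake_core"], "deep_walk", lambda term, sub: ("patched", term))

print(core.reify(1))       # ('patched', 1): reify sees the new global
print(bound_early(1, {}))  # ('original', 1): the early binding is unaffected
```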
First, the trampolined version.
```hy
=> (setv (. (get sys.modules "microkanren.core") deep_walk)
...      (. (get sys.modules "microkanren.core") deep_walk_tramp))
=> (timeit (fn [] (run 900 [x] (lambda-term° x))) :globals (globals) :number 100)
40.51025476504583
```
Then the iterative version.
```hy
=> (setv (. (get sys.modules "microkanren.core") deep_walk)
...      (. (get sys.modules "microkanren.core") deep_walk_iter))
=> (timeit (fn [] (run 900 [x] (lambda-term° x))) :globals (globals) :number 100)
33.74411466499441
```
The iterative version is significantly faster than the trampolined implementation, but it’s slower than the original recursive implementation (the naive deep_walk from the start of this post, patched in here as deep_walk_recur)!
```hy
=> (setv (. (get sys.modules "microkanren.core") deep_walk)
...      (. (get sys.modules "microkanren.core") deep_walk_recur))
=> (timeit (fn [] (run 900 [x] (lambda-term° x))) :globals (globals) :number 100)
18.099126554006943
```
I was able to improve the speed of the iterative version by:
- replacing the `MachineState` namedtuples with plain tuples;
- using single-element tuples and the empty tuple in place of `Just` and `Nothing`, respectively; and
- allocating an accumulator list of the correct length the first time a list term is encountered, and adding an index register to the state so that the accumulator elements can be updated in place, rather than allocating lots of intermediate lists.
```hy
=> (setv (. (get sys.modules "microkanren.core") deep_walk)
...      (. (get sys.modules "microkanren.core") deep_walk_iter))
=> (timeit (fn [] (run 900 [x] (lambda-term° x))) :globals (globals) :number 100)
20.935525032982696
```
That’s still a little underwhelming, so I reimplemented it using if statements rather than pattern matching and dropped the Maybe concept altogether, using a sentinel object for empty values. I found some discussions online that suggested that the code generated for match statements can have a lot of redundancy, and the disassembled code backs that up (the specifics are a topic for another post). Here’s the final iterative implementation, and timings. It runs faster than the original recursive implementation.
```python
def deep_walk_iter_if(term, sub):
    # State layout: (term, cont, ret, index). While a list is being
    # processed, ret holds its accumulator and index the next slot to fill.
    NOTHING = object()
    state = (term, NOTHING, NOTHING, NOTHING)
    while True:
        term, cont, ret, i = state
        term = walk(term, sub)
        if term is NOTHING:
            if cont is NOTHING and i is NOTHING:
                # Final state: return a term to the caller.
                return ret
            else:
                # We have a continuation, so we're "returning" a
                # sub-term to its parent list.
                x, k, acc, i = cont
                # Restore the continuation, storing ret in acc at
                # index i.
                acc[i] = ret
                # If the cont term is NOTHING, list processing is
                # done, so drop the index. Otherwise, bump the index
                # so the next time we enter this case we update the
                # next element in the list with its walked sub-term.
                index = i + 1 if x is not NOTHING else NOTHING
                state = (x, k, acc, index)
                continue
        elif isinstance(term, list) and term:
            # Start or continue processing a non-empty list. (Empty
            # lists are atomic: unpacking one below would raise.)
            a, *d = term
            # Create a new continuation state capturing the rest
            # of the list and a new accumulator.
            if i is NOTHING:
                # Allocate a new accumulator.
                length = len(d) + 1
                acc = [NOTHING] * length
                i = 0
            else:
                # Use the existing accumulator and index (the index is
                # bumped when a sub-term is returned, above).
                acc = ret
            # Set term to NOTHING if a is the final sub-term.
            term = d or NOTHING
            k = (term, cont, acc, i)
            # Process the first sub-term.
            state = (a, k, NOTHING, NOTHING)
            continue
        else:
            # Process an atomic term (including the empty list).
            state = (NOTHING, cont, term, NOTHING)
            continue
```
```hy
=> (setv (. (get sys.modules "microkanren.core") deep_walk)
...      (. (get sys.modules "microkanren.core") deep_walk_iter_if))
=> (timeit (fn [] (run 900 [x] (lambda-term° x))) :globals (globals) :number 100)
15.508090497984085
```
Here’s a comparison of the timings of the different implementations:
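| Implementation | Total time, 100 runs (seconds) |
|---|---|
| `deep_walk_tramp` | 40.51025476504583 |
| `deep_walk_iter` (`MachineState` and `Maybe`) | 33.74411466499441 |
| `deep_walk_iter` (plain tuples, preallocated accumulator) | 20.935525032982696 |
| `deep_walk_recur` | 18.099126554006943 |
| `deep_walk_iter_if` | 15.508090497984085 |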
The benchmarks are not comprehensive, so take them with a grain of salt. Thorough benchmarking would measure each implementation along controlled axes of list length and nesting depth.
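A minimal sketch of such a sweep, varying list length and nesting depth independently with the standard `timeit` module. `make_term`, `sweep`, and the `?`-prefixed `walk` stand-in are hypothetical; to compare the post’s implementations, add them to the `impls` dict and build terms and a substitution with the library’s own `Var` and `empty_sub()`.

```python
import timeit


def walk(term, sub):
    # Hypothetical shallow walk: strings starting with "?" are variables.
    while isinstance(term, str) and term.startswith("?") and term in sub:
        term = sub[term]
    return term


def deep_walk_recur(term, sub):
    # Recursive baseline, same shape as the naive deep_walk.
    term = walk(term, sub)
    if isinstance(term, list):
        return [deep_walk_recur(x, sub) for x in term]
    return term


def make_term(length, depth):
    # A list of `length` items, each `depth` singleton lists wrapped around
    # a variable. The items share one nested object, which is fine because
    # deep walking only reads its input.
    leaf = "?x"
    for _ in range(depth):
        leaf = [leaf]
    return [leaf] * length


def sweep(impls, lengths, depths, sub, number=20):
    # Returns {(name, length, depth): mean seconds per call}.
    results = {}
    for length in lengths:
        for depth in depths:
            term = make_term(length, depth)
            for name, fn in impls.items():
                total = timeit.timeit(lambda: fn(term, sub), number=number)
                results[(name, length, depth)] = total / number
    return results


# Keep depths under the recursion limit while the recursive baseline is included.
timings = sweep({"recursive": deep_walk_recur}, lengths=[10, 100], depths=[1, 50], sub={"?x": "foo"})
for (name, length, depth), secs in sorted(timings.items()):
    print(f"{name:>10}  length={length:<4} depth={depth:<3} {secs * 1e6:9.1f} us/call")
```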
In summary, we’ve explored how to transform a recursive function into an abstract machine, and the result is readable (if a little austere). What remains is to add support for more container types, especially tuples and cons pairs, so that I can use the implementation in my microkanren library. Some of the library’s other functions will need similar treatment. One benefit of this iterative, imperative implementation is that I think it will translate fairly directly to Rust, which is an experiment I’m looking forward to.