I have a scenario where I need to dynamically decorate recursive calls within a function in Python. The key requirement is to achieve this dynamically without modifying the function in the current scope. Let me explain the situation and what I've tried so far.
Suppose I have a function traverse_tree that recursively traverses a binary tree and yields its values. Now, I want to decorate the recursive calls within this function to include additional information, such as the recursion depth. When I use a decorator directly with the function, it works as expected. However, I want to achieve the same dynamically, without modifying the function in the current scope.
```python
import functools


class Node:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def generate_tree():
    root = Node(1)
    root.left = Node(2)
    root.right = Node(3)
    root.left.left = Node(4)
    root.left.right = Node(5)
    root.right.left = Node(6)
    root.right.right = Node(7)
    return root


def with_recursion_depth(func):
    """Yield recursion depth alongside original values of an iterator."""
    class Depth(int): pass
    depth = Depth(-1)

    def depth_in_value(value, depth) -> bool:
        return isinstance(value, tuple) and len(value) == 2 and value[-1] is depth

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal depth
        depth = Depth(depth + 1)
        for value in func(*args, **kwargs):
            if depth_in_value(value, depth):
                yield value
            else:
                yield value, depth
        depth = Depth(depth - 1)
    return wrapper


# 1. using @-syntax
@with_recursion_depth
def traverse_tree(node):
    """Recursively yield values of the binary tree."""
    yield node.value
    if node.left:
        yield from traverse_tree(node.left)
    if node.right:
        yield from traverse_tree(node.right)


root = generate_tree()
for item in traverse_tree(root):
    print(item)
# Output:
# (1, 0)
# (2, 1)
# (4, 2)
# (5, 2)
# (3, 1)
# (6, 2)
# (7, 2)


# 2. Dynamically:
def traverse_tree(node):
    """Recursively yield values of the binary tree."""
    yield node.value
    if node.left:
        yield from traverse_tree(node.left)
    if node.right:
        yield from traverse_tree(node.right)


root = generate_tree()
for item in with_recursion_depth(traverse_tree)(root):
    print(item)
# Output:
# (1, 0)
# (2, 0)
# (4, 0)
# (5, 0)
# (3, 0)
# (6, 0)
# (7, 0)
```
It seems the issue lies in how the recursive calls inside the function are resolved. When the decorator is applied dynamically, only the outermost call goes through the wrapper: the recursive calls in the body look up the name traverse_tree in the module's global namespace when they run, and that name still refers to the undecorated function. I can work around this by reassigning (traverse_tree = with_recursion_depth(traverse_tree)), but then the function has been modified in the current scope. I would like to do this dynamically, so that I can either use the undecorated function or optionally wrap it to obtain information on the recursion depth.
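Because the recursive call resolves traverse_tree through the function's globals at call time, one option that needs neither bytecode nor AST rewriting is to build a copy of the function with its own globals dict, and bind the name inside that dict to the decorated copy. The module-level traverse_tree is left untouched. This is a minimal sketch, not a definitive implementation: decorate_recursively is a hypothetical helper name, and it assumes a module-level function whose recursion goes through a global name (a nested function that refers to itself through a closure cell would not be affected).

```python
import functools
import types


class Node:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def generate_tree():
    # Same 7-node tree as generate_tree() in the question.
    return Node(1, Node(2, Node(4), Node(5)), Node(3, Node(6), Node(7)))


def with_recursion_depth(func):
    """Yield recursion depth alongside original values of an iterator."""
    class Depth(int): pass
    depth = Depth(-1)

    def depth_in_value(value, depth) -> bool:
        return isinstance(value, tuple) and len(value) == 2 and value[-1] is depth

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal depth
        depth = Depth(depth + 1)
        for value in func(*args, **kwargs):
            if depth_in_value(value, depth):
                yield value
            else:
                yield value, depth
        depth = Depth(depth - 1)
    return wrapper


def decorate_recursively(func, decorator):
    """Hypothetical helper: return decorator(copy of func) whose recursive
    calls also go through the decorator, without rebinding func's global name."""
    # Private snapshot of the module globals; only the copy sees this dict.
    new_globals = dict(func.__globals__)
    clone = types.FunctionType(
        func.__code__, new_globals, func.__name__, func.__defaults__, func.__closure__
    )
    clone.__kwdefaults__ = func.__kwdefaults__
    decorated = decorator(clone)
    # Recursive calls inside the clone look up func.__name__ in new_globals,
    # so they reach this single wrapper (and share its depth counter).
    new_globals[func.__name__] = decorated
    return decorated


def traverse_tree(node):
    """Recursively yield values of the binary tree."""
    yield node.value
    if node.left:
        yield from traverse_tree(node.left)
    if node.right:
        yield from traverse_tree(node.right)


root = generate_tree()
print(list(decorate_recursively(traverse_tree, with_recursion_depth)(root)))
# [(1, 0), (2, 1), (4, 2), (5, 2), (3, 1), (6, 2), (7, 2)]
print(list(traverse_tree(root)))  # the module-level function is unchanged
# [1, 2, 4, 5, 3, 6, 7]
```

One trade-off of this design: the copied globals are a snapshot, so globals added or rebound in the module after decorate_recursively is called are not visible to the copy.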
I prefer to keep things simple and would like to avoid techniques like bytecode manipulation if there are alternative solutions. However, if that's the necessary path, I'm willing to explore it. I've made an attempt in that direction, but I haven't been successful yet.
```python
import ast
import inspect
import types


def modify_recursive_calls(func, decorator):
    def decorate_recursive_calls(node):
        # Rewrite each call `func(...)` into `decorator(func)(...)`.
        # Note: every rewritten call builds a new wrapper, each with its own state.
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == func.__name__:
            func_name = ast.copy_location(ast.Name(id=node.func.id, ctx=ast.Load()), node.func)
            decorated_func = ast.Call(
                func=ast.Name(id=decorator.__name__, ctx=ast.Load()),
                args=[func_name],
                keywords=[],
            )
            node.func = decorated_func
        # Recurse into every child node of the syntax tree.
        for field, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        decorate_recursive_calls(item)
            elif isinstance(value, ast.AST):
                decorate_recursive_calls(value)

    tree = ast.parse(inspect.getsource(func))
    decorate_recursive_calls(tree)
    ast.fix_missing_locations(tree)
    modified_code = compile(tree, filename="<ast>", mode="exec")
    # The function's code object is one of the module's constants; its index
    # differs between Python versions, so select it by type.
    function_code = next(c for c in modified_code.co_consts if isinstance(c, types.CodeType))
    modified_function = types.FunctionType(function_code, func.__globals__)
    return modified_function
```
traverse_tree is in the global scope. One thing you could do is patch whatever is in the global scope, although it might get messy. traverse_tree gets rebound to wrapper, not just that wrapper gets executed.
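The suggestion to patch the global scope can be sketched as a context manager that rebinds the function's global name to the decorated version and restores the original on exit. This is a hedged sketch: recursion_decorated is a hypothetical helper name, it mutates the module namespace for the duration of the block (so it is not thread-safe), and because generators are lazy the result must be consumed inside the with block, where the recursive lookups still see the wrapper.

```python
import contextlib
import functools


class Node:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def generate_tree():
    # Same 7-node tree as generate_tree() in the question.
    return Node(1, Node(2, Node(4), Node(5)), Node(3, Node(6), Node(7)))


def with_recursion_depth(func):
    """Yield recursion depth alongside original values of an iterator."""
    class Depth(int): pass
    depth = Depth(-1)

    def depth_in_value(value, depth) -> bool:
        return isinstance(value, tuple) and len(value) == 2 and value[-1] is depth

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal depth
        depth = Depth(depth + 1)
        for value in func(*args, **kwargs):
            if depth_in_value(value, depth):
                yield value
            else:
                yield value, depth
        depth = Depth(depth - 1)
    return wrapper


@contextlib.contextmanager
def recursion_decorated(func, decorator):
    """Hypothetical helper: temporarily rebind func's global name to decorator(func)."""
    namespace = func.__globals__
    name = func.__name__
    original = namespace[name]
    decorated = decorator(func)
    namespace[name] = decorated  # recursive calls now resolve to the wrapper
    try:
        yield decorated
    finally:
        namespace[name] = original  # restore whatever was bound before


def traverse_tree(node):
    """Recursively yield values of the binary tree."""
    yield node.value
    if node.left:
        yield from traverse_tree(node.left)
    if node.right:
        yield from traverse_tree(node.right)


root = generate_tree()
with recursion_decorated(traverse_tree, with_recursion_depth) as traverse_with_depth:
    # Consume the generator here: the recursive name lookups happen lazily.
    print(list(traverse_with_depth(root)))
# [(1, 0), (2, 1), (4, 2), (5, 2), (3, 1), (6, 2), (7, 2)]
print(list(traverse_tree(root)))  # original binding restored after the block
# [1, 2, 4, 5, 3, 6, 7]
```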