Closures are being gc'd and causing failures to compile

Hi!

I’ve been contributing to pytensor, adding the PyTorch backend. The library generates a symbolic graph, and then

  1. Calls a series of specializations per operation
  2. Glues them together with some generated code.

An example of the generated code:

def pytorch_funcified_fgraph(x, y, z):
    # OpFromGraph{inline=False}(y, z)
    tensor_variable_2, tensor_variable_3 = pytorch_funcified_fgraph_0(y, z)
    # OpFromGraph{inline=False}(x, OpFromGraph{inline=False}.0)
    tensor_variable_5 = pytorch_funcified_fgraph_1(x, tensor_variable_2)
    # True_div(OpFromGraph{inline=False}.0, OpFromGraph{inline=False}.1)
    tensor_variable_6 = elemwise_fn_3(tensor_variable_5, tensor_variable_3)
    return (tensor_variable_6,)

and an example of the dispatches:

@pytorch_funcify.register(Elemwise)
def pytorch_funcify_Elemwise(op, node, **kwargs):
    scalar_op = op.scalar_op
    base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)

    def elemwise_fn(*inputs):
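        # base_fn and node are free variables captured in elemwise_fn's
        # closure cells; the failing guard below points at one of these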
        Elemwise._check_runtime_broadcast(node, inputs)
        return base_fn(*inputs)

    return elemwise_fn
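
For context, the glued function is ultimately what gets handed to torch.compile, roughly like so (a sketch, assuming the generated pytorch_funcified_fgraph from above and three tensor inputs):

import torch

# pytorch_funcified_fgraph is the generated glue function shown earlier
compiled_fn = torch.compile(pytorch_funcified_fgraph)
x, y, z = (torch.ones(3) for _ in range(3))
(result,) = compiled_fn(x, y, z)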

When all is said and done, you have a Python function that glues all the dispatches together, threading tensors between the dispatch calls. Torch wasn’t able to compile this function because it would fail to create a guard. The guards would fail to be created with this error:

[2024-11-22 10:29:01,151] [44/0] torch._guards: [ERROR] Name: "G['__import_pytensor_dot_link_dot_utils'].elemwise_fn.__closure__[1].cell_contents"
[2024-11-22 10:29:01,151] [44/0] torch._guards: [ERROR]     Source: global
[2024-11-22 10:29:01,151] [44/0] torch._guards: [ERROR]     Create Function: TYPE_MATCH
[2024-11-22 10:29:01,151] [44/0] torch._guards: [ERROR]     Guard Types: None
[2024-11-22 10:29:01,151] [44/0] torch._guards: [ERROR]     Code List: None
[2024-11-22 10:29:01,151] [44/0] torch._guards: [ERROR]     Object Weakref: None
[2024-11-22 10:29:01,151] [44/0] torch._guards: [ERROR]     Guarded Class Weakref: None

The only ways we found to get around this problem were: to disable the “leaf” nodes (in the case of the error above, elemwise_fn would need a torch.compiler.disable decorator); to run with suppress_errors = True; or what I ended up doing in this PR, which is to explicitly call setattr on the pytensor.link.utils module, attaching the return value of each dispatch function to it.
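
Roughly, that workaround looks like this (a minimal sketch; register_on_module and the uniquifying scheme are illustrative, not the exact PR code):

import itertools

import pytensor.link.utils

_unique = itertools.count()

def register_on_module(fn):
    # Attach the closure-carrying function to a real module so it stays
    # reachable (not GC'd) and dynamo's global guard source, e.g.
    # G['__import_pytensor_dot_link_dot_utils'].elemwise_fn..., can resolve it
    setattr(pytensor.link.utils, f"{fn.__name__}_{next(_unique)}", fn)
    return fn

Each dispatch then returns register_on_module(elemwise_fn) instead of elemwise_fn directly.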

I did some digging through the compile logs, but I couldn’t quite pin down exactly what was happening at guard creation time. I could see that the generated code was doing a pretty good job of inlining all the necessary calls; it would look like this:

class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor, L_z_ : torch.Tensor):
        input_4 = L_x_
        input_2 = L_y_
        input_3 = L_z_
        
        # File: /Users/ischweer/dev/pytensor/pytensor/link/pytorch/dispatch/elemwise.py:29, code: return base_fn(*inputs)
        tensor_variable = torch.add(input_4, input_2);  input_2 = None
        
        # File: /Users/ischweer/dev/pytensor/pytensor/link/pytorch/dispatch/elemwise.py:29, code: return base_fn(*inputs)
        tensor_variable_1 = torch.add(input_3, input_4);  input_3 = input_4 = None
        return (tensor_variable, tensor_variable_1)

But my best explanation for why the guards were failing is that something is being GC’d (even though it could still be run in the plain Python runtime for…some reason), or that pytorch only knew these methods by their module location, which would fit, given that the setattr workaround makes that module attribute actually exist. Is there any help I could get on narrowing down the issue?
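
For anyone who wants to poke at the GC theory, the closure cells the guard source names can be inspected directly. A quick sketch (elemwise_fn here is the function returned by the dispatch above):

import gc

# The failing guard pointed at elemwise_fn.__closure__[1].cell_contents
for i, cell in enumerate(elemwise_fn.__closure__ or ()):
    try:
        contents = cell.cell_contents
        print(i, type(contents), gc.is_tracked(contents))
    except ValueError:
        # an empty cell: its contents have already been cleared
        print(i, "<empty cell>")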

I’m also attempting to make an MWE without pytensor to see if I can replicate this outside the library. It’s a little tricky though :slight_smile:
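
The general shape of what I’m trying is roughly this (a sketch that mimics the dispatch-plus-glue structure above; all names here are illustrative):

import torch

def make_elemwise(base_fn, node):
    # Mirrors pytorch_funcify_Elemwise: base_fn and node end up in
    # elemwise_fn's closure cells
    def elemwise_fn(*inputs):
        assert node is not None  # stand-in for the runtime broadcast check
        return base_fn(*inputs)
    return elemwise_fn

add_fn = make_elemwise(torch.add, node={"ndim": 1})

def glued(x, y, z):
    # Mirrors the generated pytorch_funcified_fgraph glue code
    a = add_fn(x, y)
    b = add_fn(z, x)
    return a, b

compiled = torch.compile(glued)
print(compiled(torch.ones(3), torch.ones(3), torch.ones(3)))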

CC @ezyang in case you have an idea about the GC behavior.