Skip a submodule that cannot be compiled

Hey all,

I am using xformers to perform fast scaled dot-product attention, and that function is not compatible with torch.compile.

It is the only place in my code that is not compatible, though. I would like to be able to call torch.compile() on my model and just “skip” the part that calls the xformers function.

Right now, to achieve this, I have to apply torch.compile manually to every layer except the xformers one. Is there a way to do this automatically?
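The manual workaround might look roughly like this minimal sketch, assuming PyTorch 2.x. `XformersAttention` is a hypothetical stand-in for the incompatible submodule (written with plain torch ops so it runs without xformers installed); every other child of the model is replaced in place with its compiled version. `backend="eager"` is only there so the sketch runs without a C++/Triton toolchain; drop it to use the default backend.

```python
import torch
import torch.nn as nn


class XformersAttention(nn.Module):
    """Hypothetical stand-in for the submodule that calls xformers."""

    def forward(self, x):
        # Plain self-attention with torch ops, in place of the xformers kernel.
        scale = x.shape[-1] ** -0.5
        weights = torch.softmax((x @ x.transpose(-2, -1)) * scale, dim=-1)
        return weights @ x


model = nn.Sequential(
    nn.Linear(16, 16),
    XformersAttention(),
    nn.Linear(16, 16),
)

# Compile every child except the incompatible one, swapping it in place.
# list() snapshots the children before we start reassigning them.
for name, child in list(model.named_children()):
    if not isinstance(child, XformersAttention):
        setattr(model, name, torch.compile(child, backend="eager"))

out = model(torch.randn(2, 8, 16))
print(out.shape)
```

This works, but it has to be redone by hand for every model layout, which is exactly the tedium the question is about.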

Thanks in advance

Depending on your exact use case, it sounds like you might want torch._dynamo.disable or torch._dynamo.disallow_in_graph. See here: Frequently Asked Questions — PyTorch 2.1 documentation
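A minimal sketch of the torch._dynamo.disable route, assuming PyTorch 2.x. `xformers_attention` and `Block` are hypothetical names: the attention is written with plain torch ops so the example runs without xformers, but in real code you would decorate (or wrap) the function that actually calls xformers. Dynamo skips the decorated function and runs it eagerly, while the layers around it are still compiled, at the cost of a graph break at that call. `backend="eager"` only keeps the sketch runnable without a compiler toolchain. disallow_in_graph is not shown here; as far as I know it is meant for callables Dynamo would otherwise put into the graph as-is (such as torch.* ops), not for arbitrary Python functions.

```python
import torch
import torch._dynamo
import torch.nn as nn


# Hypothetical stand-in for the function that calls xformers.
# torch._dynamo.disable tells Dynamo not to trace this function:
# it runs eagerly, with a graph break before and after the call.
@torch._dynamo.disable
def xformers_attention(q, k, v):
    scale = q.shape[-1] ** -0.5
    weights = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return weights @ v


class Block(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)  # traced and compiled
        self.proj = nn.Linear(dim, dim)     # traced and compiled

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        out = xformers_attention(q, k, v)   # skipped by Dynamo
        return self.proj(out)


torch.manual_seed(0)
model = Block(16)
# Compile the whole model in one call; no per-layer bookkeeping needed.
compiled = torch.compile(model, backend="eager")

x = torch.randn(2, 8, 16)
out = compiled(x)
print(torch.allclose(out, model(x)))
```

Compared with compiling each layer by hand, this keeps a single torch.compile call on the model and marks only the incompatible function.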

That is exactly what I was looking for. Thanks!

By any chance, do you know why these are private functions? Does that just mean they are experimental for now and will be stabilized later?

Great! I’m not sure how stable these methods are but the fact that they’re documented implies to me that they’re reasonably safe to use.