I am using xformers to compute fast scaled dot-product attention, but the xformers function is not compatible with torch.compile.
It is the only part of my code that is incompatible, so I would like to call torch.compile() on the whole model and simply “skip” the part that calls the xformers function during compilation.
Right now, to get the same effect, I have to apply torch.compile manually to every layer except the xformers one. Is there a way to achieve this automatically?
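For context, a minimal sketch of the per-layer workaround might look like the following. `XformersAttention`, `Block` and `compile_except` are hypothetical names for illustration, not my actual code. `xformers.ops.memory_efficient_attention` is the real xformers call. The fallback to `torch.nn.functional.scaled_dot_product_attention` when xformers is not installed is an assumption added only so the sketch runs anywhere.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import xformers.ops as xops  # the dependency that torch.compile cannot handle
except ImportError:  # assumption: stand-in so the sketch runs without xformers
    xops = None


class XformersAttention(nn.Module):
    """Hypothetical self-attention layer whose core call goes through xformers."""

    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        # q, k, v each have layout [batch, seq, heads, head_dim] (xformers layout)
        q, k, v = self.qkv(x).view(b, n, 3, self.n_heads, d // self.n_heads).unbind(2)
        if xops is not None:
            out = xops.memory_efficient_attention(q, k, v)
        else:
            # PyTorch SDPA expects [batch, heads, seq, head_dim], so transpose around it
            out = F.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
            ).transpose(1, 2)
        return self.proj(out.reshape(b, n, d))


class Block(nn.Module):
    """Hypothetical transformer block mixing compilable and non-compilable layers."""

    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = XformersAttention(dim, n_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))


def compile_except(module: nn.Module, skip: type, **compile_kwargs) -> None:
    """Replace, in place, every child subtree that contains no `skip` module with
    its torch.compile'd version; descend into subtrees that do contain one."""
    for name, child in module.named_children():
        if isinstance(child, skip):
            continue  # leave the xformers layer in eager mode
        if any(isinstance(m, skip) for m in child.modules()):
            compile_except(child, skip, **compile_kwargs)  # mixed subtree: go deeper
        else:
            setattr(module, name, torch.compile(child, **compile_kwargs))


model = nn.Sequential(Block(64, 4), Block(64, 4))
compile_except(model, skip=XformersAttention)
```

The drawback of this approach is that only the leaf layers get compiled: the glue code in `Block.forward` (the residual additions) still runs eagerly, and every compiled layer is a separate graph. That is why I would prefer a single `torch.compile(model)` call that just leaves the xformers call out.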