I have a frustrating bug that I don’t know how to begin to solve. Essentially - if I torch.compile my model, then it trains normally. I get the results I would expect, I can do inference etc. etc.
To speed up my development loop I tried removing the torch.compile, and suddenly my model cannot improve past random performance. I have tested two identical configurations, that differ only in the use of torch.compile.
It is quite a lot of code but I can give you a few snippets. (Edit*: I can send more if required, but maybe I will work towards a min repro. first…)
The train loop is something like this:
if self._loss_fn is None:
if self.torch_compile:
print("Compiling loss function...")
self._loss_fn = torch.compile(self.compute_loss, fullgraph=True, backend="eager")
else:
self._loss_fn = self.compute_loss
for features in enumerate(loader):
loss, *_ = self._loss_fn(**features)
loss.backward()
...
No and I would also assume that the default “eager” mode should already fail or raise warnings in case something odd happens as I would assume torch.compile has more limitations. However, in your use case it seems torch.compile works while the default eager mode works, which is something I haven’t seen before.
Are you using CUDAGraphs in eager mode as it’s not always trivial to make sure the same buffers are used?
I am not familiar with CUDAGraphs, and also fairly new to torch.compile.
So I think the answer is no? I have tested three settings:
no torch.compile (training fails)
torch.compile() training succeeds
torch.compile(…, backend=“eager”) - training succeeds.
Possible I am doing something silly elsewhere, but yeah its really weird behaviour. I’m just going to start stripping things out until compile and non-compiled give the same results.
Will open an issue if I can identify an actual error…
If I remove the torch.jit.script, my model converges (regardless of the outer torch.compile). However, with the jit.script my model only converges with an outer torch.compile (which I think is taking precedence / negating the jit.script somehow?).
This is pretty weird, because when I compare the scripted vs. non-scripted layer norm in a notebook, the results are basically identical? (sometimes a tiny epsilon difference which seems probably consistent with floating point errors?).
This is really interesting and great debugging! Thanks a lot for sharing the update.
I don’t know how torch.compile would interact with an internal TorchScript module, but note that TorchScript is in maintenance mode the current recommendation would be to use torch.compile (only).
Does the model work fine without @torch.jit.script in eager mode vs. torch.compile?
CC @marksaroufim for viz as it’s an interesting issue.