Model only converges when wrapped in `torch.compile`

Hi,

I have a frustrating bug that I don't know how to begin to solve. Essentially: if I torch.compile my model, it trains normally. I get the results I would expect, I can run inference, etc.

To speed up my development loop I tried removing the torch.compile, and suddenly the model cannot improve past random performance. I have tested two otherwise identical configurations that differ only in the use of torch.compile.

Does anyone have any idea how this is possible?

Could you post the model definition as well as the training loop, please?

It is quite a lot of code, but I can give you a few snippets. (Edit: I can send more if required, but maybe I will work towards a minimal repro first…)

The train loop is something like this:

        if self._loss_fn is None:
            if self.torch_compile:
                print("Compiling loss function...")
                self._loss_fn = torch.compile(self.compute_loss, fullgraph=True, backend="eager")
            else:
                self._loss_fn = self.compute_loss

        for step, features in enumerate(loader):  # features is a dict of batch tensors
            loss, *_ = self._loss_fn(**features)
            loss.backward()
            ...

compute_loss looks something like:

# forward pass
pad_idx = 1024
nb_codes = self.model.nb_codes
codes_shifted = F.pad(codes, (1, 0), "constant", pad_idx)  # (B, nb_codebooks, L+1)
logits = self.model(
    x=codes_shifted,
    ...
)  # (B, nb_codebooks * nb_codes, L)

loss = F.cross_entropy(logits, codes, ignore_index=pad_idx, reduction="mean")

return loss, {}

Within the model code itself, can you think of any specific ops that are likely to generate a difference in compiled vs un-compiled?

A few extra pieces of info:

  • The model is a deep autoregressive model; inference works without compilation (I get good samples from it).
  • Compiling with backend="eager" also converges, so I guess graph capture is changing something (in a way that happens to be beneficial/stabilizing)?

No, and I would assume that the default eager mode should already fail or raise warnings if something odd happens, since torch.compile has more limitations. However, in your use case torch.compile works while the default eager mode fails, which is something I haven't seen before.
Are you using CUDAGraphs in eager mode? It's not always trivial to make sure the same buffers are reused.
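
For context, manual CUDA Graph capture in eager mode follows roughly this pattern (a minimal sketch; the model and shapes are placeholders), and the static input buffer is what I mean by reusing the same buffers:

import torch

# Placeholder model and shapes, just to illustrate the static-buffer requirement.
model = torch.nn.Linear(128, 128).cuda()
static_input = torch.randn(8, 128, device="cuda")

# Warm-up on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture the forward pass into a graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# Replay: new data must be copied into the *same* input buffer.
static_input.copy_(torch.randn(8, 128, device="cuda"))
g.replay()  # static_output now holds the results for the new data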

I am not familiar with CUDAGraphs, and also fairly new to torch.compile.

So I think the answer is no? I have tested three settings (toggled roughly as in the sketch below):

  • no torch.compile - training fails
  • torch.compile() - training succeeds
  • torch.compile(…, backend="eager") - training succeeds
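
Concretely, the three runs only differ in how the loss function gets wrapped, along these lines (a sketch of the toggle; the mode names are placeholders, not my actual config keys):

import torch

def make_loss_fn(compute_loss, mode: str):
    if mode == "compile":                 # default backend: training succeeds
        return torch.compile(compute_loss, fullgraph=True)
    if mode == "compile_eager_backend":   # backend="eager": training also succeeds
        return torch.compile(compute_loss, fullgraph=True, backend="eager")
    return compute_loss                   # plain eager: training fails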

Possibly I am doing something silly elsewhere, but yeah, it's really weird behaviour. I'm just going to start stripping things out until the compiled and non-compiled runs give the same results.

Will open an issue if I can identify an actual error…


Sounds good! Let me know once you have an executable code snippet reproducing the issue.

I have made some progress in narrowing down the cause of the issue, but it's still quite mysterious.

In my code I have the following layer-norm, with (B, C, L) ordering:

@torch.jit.script
def layer_norm_no_bias(x, gamma):
    mean = x.mean(dim=1, keepdim=True)  # (B, C, L) -> (B, 1, L)
    var = x.var(dim=1, keepdim=True, unbiased=False)  # (B, C, L) -> (B, 1, L)
    x = (x - mean) / torch.sqrt(var + 1e-6)  # (B, C, L)
    return gamma * x
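
For reference, gamma is a learnable per-channel scale, and the function is used roughly like this (a sketch assuming a (1, C, 1) gamma; not my exact module):

import torch
import torch.nn as nn

class LayerNormNoBias(nn.Module):
    """Channel-wise layer norm for (B, C, L) tensors: learnable scale, no bias."""
    def __init__(self, num_channels: int):
        super().__init__()
        # Broadcasts over the batch and length dimensions.
        self.gamma = nn.Parameter(torch.ones(1, num_channels, 1))

    def forward(self, x):
        # Calls the (scripted) function defined above.
        return layer_norm_no_bias(x, self.gamma)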

If I remove the @torch.jit.script, my model converges (regardless of the outer torch.compile). However, with the jit.script the model only converges with an outer torch.compile (which I think takes precedence over / negates the jit.script somehow?).

This is pretty weird, because when I compare the scripted vs. non-scripted layer norm in a notebook, the forward results are basically identical (sometimes a tiny difference, which seems consistent with floating-point error).
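
The notebook check is essentially this (a sketch, not the exact code; sizes are placeholders):

import torch

def layer_norm_no_bias_eager(x, gamma):
    # Same math as above, but without @torch.jit.script
    mean = x.mean(dim=1, keepdim=True)
    var = x.var(dim=1, keepdim=True, unbiased=False)
    return gamma * ((x - mean) / torch.sqrt(var + 1e-6))

scripted = torch.jit.script(layer_norm_no_bias_eager)

x = torch.randn(2, 64, 128)   # (B, C, L)
gamma = torch.ones(1, 64, 1)

out_eager = layer_norm_no_bias_eager(x, gamma)
out_scripted = scripted(x, gamma)
print(torch.allclose(out_eager, out_scripted, atol=1e-6))  # matches for me, up to tiny fp noise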

This is really interesting and great debugging! Thanks a lot for sharing the update.
I don't know how torch.compile would interact with an internal TorchScript module, but note that TorchScript is in maintenance mode, so the current recommendation is to use torch.compile (only).
Does the model work fine without @torch.jit.script in both eager mode and with torch.compile?

CC @marksaroufim for visibility, as it's an interesting issue.


Yes, it works in both modes without the torch.jit.script. It actually feels like two bugs:

  1. a weird/silent interaction between torch.compile and torch.jit.script
  2. a numerical error in the backward pass of torch.jit.script

I opened an issue for the second one, with a minimal reproducible example that shows jit.script producing wrong gradients.
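
The repro boils down to comparing gradients between the scripted and unscripted versions, along these lines (a sketch, not the exact code from the issue; sizes are placeholders):

import torch

def ln_no_bias(x, gamma):
    mean = x.mean(dim=1, keepdim=True)
    var = x.var(dim=1, keepdim=True, unbiased=False)
    return gamma * ((x - mean) / torch.sqrt(var + 1e-6))

ln_scripted = torch.jit.script(ln_no_bias)

def grads(fn):
    torch.manual_seed(0)
    x = torch.randn(2, 64, 128, requires_grad=True)   # (B, C, L)
    gamma = torch.ones(1, 64, 1, requires_grad=True)
    # TorchScript's profiling executor may only optimize the graph after a few
    # calls, so run several iterations and compare the last set of gradients.
    for _ in range(5):
        x.grad, gamma.grad = None, None
        fn(x, gamma).sum().backward()
    return x.grad, gamma.grad

for eager_g, scripted_g in zip(grads(ln_no_bias), grads(ln_scripted)):
    print(torch.allclose(eager_g, scripted_g, atol=1e-5),
          (eager_g - scripted_g).abs().max().item())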

I will just avoid torch.jit.script for now.

Not gonna lie, this issue was pretty funny when I first saw it. I suspect Dynamo has recently started skipping jit, because a few weeks ago I had to skip jit export here: EMFORMER_RNNT not compilable · Issue #106101 · pytorch/pytorch · GitHub

Maybe @bdhirsh has some ideas as to what's going on.
