I’m experiencing an error that happens when I call loss.backward():
```
Asserting FSDP instance is: FullyShardedDataParallel(
  (_fpw_module): Linear(in_features=2, out_features=128, bias=True)
ERROR: expected to be in states [<TrainingState_.BACKWARD_PRE: 3>, <TrainingState_.BACKWARD_POST: 4>] but current state is TrainingState_.IDLE
```
One interesting thing I found (not sure if it helps): the module that raises the error is Linear(in_features=2, out_features=128, bias=True). My model has three of these, and _register_pre_backward_hooks is called to register hooks on all of them. Two of them actually run _pre_backward_hook successfully, but for the last one, _pre_backward_hook is never called before _post_backward_hook runs, even though it is registered.
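As far as we understand FSDP in PyTorch 1.13, _register_pre_backward_hooks attaches hooks to the module's forward outputs with Tensor.register_hook, while the post-backward hook runs when the gradient for the module's (flattened) parameter is accumulated. The sketch below uses plain autograd and no FSDP at all: it hooks the output and the weight of an ordinary nn.Linear(2, 128) to show the order you would normally expect. The `events` list and the "pre"/"post" labels are illustrative stand-ins, not FSDP internals.

```python
import torch
import torch.nn as nn

# Records the order in which the two hooks fire during backward.
events = []

lin = nn.Linear(2, 128)
x = torch.randn(4, 2)

out = lin(x)
# Hook on the forward output: a stand-in for FSDP's pre-backward hook.
out.register_hook(lambda grad: events.append("pre"))
# Hook on the parameter's gradient: a stand-in for FSDP's post-backward hook.
lin.weight.register_hook(lambda grad: events.append("post"))

out.relu().sum().backward()
print(events)  # ['pre', 'post']
```

The gradient with respect to `out` has to be computed before the weight gradient can be derived from it, so in the normal case the output hook always fires first.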
I’m on PyTorch 1.13, and I’m running backward on a single loss. Has anyone seen this error before?
I’ve tried hard to isolate a smaller snippet but unfortunately couldn’t: the model is big, and the problem seems related to something that happens during the forward pass. When I mocked the forward outputs, loss.backward() worked fine. Digging out which part of the forward pass causes this is hard and likely unproductive, so I’m posting in the hope that someone has some high-level insights.
```
  File ".../trainer.py", line 269, in train_epoch
    metas = self.train_step(data)
  File ".../trainer.py", line 309, in train_step
  File ".../trainer.py", line 320, in update_params
  File ".../pip_torch/torch/_tensor.py", line 488, in backward
  File ".../pip_torch/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File ".../pip_torch/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File ".../pip_torch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 3215, in _post_backward_hook
  File ".../pip_torch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 3583, in _assert_state
ValueError: expected to be in states [<TrainingState_.BACKWARD_PRE: 3>, <TrainingState_.BACKWARD_POST: 4>] but current state is TrainingState_.IDLE
```
OK, now I understand why you are doing that wrapping. Before answering the question about FSDP, may I ask what happens inside your loss? Did you calculate the loss and also call backward inside the loss function? If so, how did you call backward() there?
Also, you don’t need to call .forward() yourself; you can just call the module directly, as in model(x).
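A minimal sketch of why calling the module is preferable to calling .forward() directly, shown with a plain nn.Linear and no FSDP: model(x) goes through nn.Module.__call__, which runs registered forward hooks, while model.forward(x) skips them. The `calls` list is illustrative; how much this matters for a particular FSDP wrapping depends on where the wrapper does its bookkeeping.

```python
import torch
import torch.nn as nn

calls = []

model = nn.Linear(2, 128)
# Forward hooks are run by nn.Module.__call__, not by forward() itself.
model.register_forward_hook(lambda mod, inp, out: calls.append("forward_hook"))

x = torch.randn(4, 2)

out = model(x)                 # goes through __call__, so the hook runs
out_direct = model.forward(x)  # bypasses __call__, so the hook is skipped

print(calls)  # ['forward_hook']
```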
Hi, thanks for your reply. The code looks legitimate to me, and after chatting with @agu, we suspect there may be some limitation on the FSDP side. But from the information you’ve given us, it’s hard to infer what that limitation is.
If you can provide a minimal repro, it will help us prioritize a fix; otherwise, it may take a while before we can get to this.
Also, have you tried not calling .forward() and doing what I mentioned above instead?
Thanks for helping! I’m trying to isolate the part of the model that causes this. I found the module that FSDP always crashes at, but when I run it separately (even with the same input data), it succeeds, as in this standalone snippet:
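```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Linear(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.linear = nn.Linear(n, n, bias=False)
        self.norm = nn.LayerNorm(n)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.linear(x)
        out = self.norm(out)
        out = self.relu(out)
        return out

class MyModel(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.layer = nn.Sequential(nn.Linear(2, n), nn.ReLU(inplace=True), Linear(n))

    def forward(self, x):
        return self.layer(x)

if __name__ == "__main__":
    init_dist()  # my helper function that sets all the distributed env variables
    m = MyModel(128).to("cuda")
    model = FSDP(m)
    model = model.train()
    x = torch.load("/tmp/input_saved_from_training.pt")
    out = model(x)
    loss = out.sum()
    loss.backward()  # succeeds in this standalone snippet
```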
It’s likely some entangled logic in the model that causes this. I’m curious whether you have any debugging tips other than adding pieces to the snippet bit by bit, which is quite unproductive.
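One possible alternative to bisecting by hand, sketched here as an assumption rather than a tested recipe: attach a forward hook to every submodule that registers a tensor hook on its output, run one forward/backward step, and list the submodules whose outputs never received a gradient. `trace_output_grads` and the toy `Demo` model are hypothetical; the helper only handles modules that return a single tensor, and we have not tried it together with FSDP.

```python
import torch
import torch.nn as nn


def trace_output_grads(model: nn.Module):
    """Hypothetical helper: record which submodule outputs get a gradient."""
    produced, received = [], set()

    def make_fwd_hook(name):
        def fwd_hook(mod, inp, out):
            # Only single-tensor outputs are handled; tuples/dicts would need flattening.
            if isinstance(out, torch.Tensor) and out.requires_grad:
                produced.append(name)
                out.register_hook(lambda grad: received.add(name))
        return fwd_hook

    handles = [
        m.register_forward_hook(make_fwd_hook(name))
        for name, m in model.named_modules()
        if name  # skip the root module itself
    ]
    return produced, received, handles


class Demo(nn.Module):
    # Hypothetical toy model: one branch feeds the loss, one does not.
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(2, 4)
        self.unused = nn.Linear(2, 4)

    def forward(self, x):
        _unused_out = self.unused(x)  # computed but never reaches the loss
        return self.used(x)


model = Demo()
produced, received, handles = trace_output_grads(model)
model(torch.randn(3, 2)).sum().backward()
for h in handles:
    h.remove()

missing = [name for name in produced if name not in received]
print(missing)  # ['unused']
```

On the real model, the modules listed in `missing` would be the first places to look.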
One thing I noticed: the Linear module here finishes its post-forward hook, which sets its state to IDLE, but never reaches its pre-backward hook and jumps directly to the post-backward hook, even though the pre-backward hooks are registered. What are the possible reasons that a module’s pre-backward hook isn’t triggered?
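One way to get a parameter hook without a preceding output hook, at least in plain autograd, is for the parameter to receive its gradient through a path that never passes through the hooked output, for example when other code reads the weight directly (weight tying, a regularizer, and so on). This is an assumption about what could produce the pattern, not a diagnosis of this model, and it uses no FSDP: the loss depends only on `lin.weight`, so the output hook never fires while the parameter hook does.

```python
import torch
import torch.nn as nn

events = []

lin = nn.Linear(2, 128)
x = torch.randn(4, 2)

out = lin(x)
# Output hook: a stand-in for FSDP's pre-backward hook.
out.register_hook(lambda grad: events.append("pre"))
# Parameter hook: a stand-in for FSDP's post-backward hook.
lin.weight.register_hook(lambda grad: events.append("post"))

# Hypothetical "entangled" use: the weight is read directly elsewhere,
# and only that path reaches the loss, so autograd never visits `out`.
loss = (lin.weight ** 2).sum()
loss.backward()

print(events)  # ['post']
```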
Thanks for your reply. That’s exactly why we asked for a minimal repro. The example you gave works with FSDP, and it differs from what you described above (the joint loss, for example).
Also, the pre-backward hooks may not be triggered because something went wrong in FSDP’s hook registration or in its interaction with PyTorch autograd.