FSDP module crash during backward due to `TrainingState_.IDLE`

I’m experiencing an error that happens when I call loss.backward():

Asserting FSDP instance is: FullyShardedDataParallel(
  (_fsdp_wrapped_module): FlattenParamsWrapper(
    (_fpw_module): Linear(in_features=2, out_features=128, bias=True)
  )
)
ERROR: expected to be in states [<TrainingState_.BACKWARD_PRE: 3>, <TrainingState_.BACKWARD_POST: 4>] but current state is TrainingState_.IDLE

I tried to isolate a smaller snippet but couldn’t: the model is big, and the problem seems related to something that happens during the forward pass. When I mocked the forward outputs, the loss backward worked fine. Digging out which part of the forward pass causes this is hard and likely unproductive.

One interesting thing (not sure if it helps): the module raising the error is Linear(in_features=2, out_features=128, bias=True). My model has three of them, and _register_pre_backward_hooks is called to register hooks for all three. For two of them, _pre_backward_hook actually runs and succeeds; for the last one, _pre_backward_hook is never called, even though it is registered, before _post_backward_hook gets called.

I’m on PyTorch 1.13. I’m running backward on one loss. Have we seen this error before?

Do you have a code example we can use to repro this?

I tried but unfortunately couldn’t isolate a smaller snippet: the model is big, and the problem seems related to something that happens during the forward pass. When I mocked the forward outputs, the loss backward worked fine. Digging out which part of the forward pass causes this is hard and likely unproductive, so I posted here hoping someone might have some high-level insights.

Got it. That’s why I’m asking how you wrap your model with FSDP and how you run your forward pass, so that we can better help here. Also, can you paste the error stack trace here?

Thanks!

I simply wrap my model with

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
wrapped_model = FSDP(model)
# load data
total_loss = wrapped_model.forward(data)
total_loss.backward()

The forward pass runs two models, so it looks something like:

def forward(x):
  x1 = wrapped_model.first_net.forward(x)
  x2 = wrapped_model.second_net.forward(x1)
  total_loss = wrapped_model.loss(x1, x2)
  return total_loss

This is the stack trace:

".../trainer.py", line 269, in train_epoch
    metas = self.train_step(data)
  File ".../trainer.py", line 309, in train_step
    self.update_params(total_loss, metas)
  File ".../trainer.py", line 320, in update_params
    total_loss.backward()
  File ".../pip_torch/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File ".../pip_torch/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File ".../pip_torch/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File ".../pip_torch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 3215, in _post_backward_hook
    self._assert_state([TrainingState_.BACKWARD_PRE, TrainingState_.BACKWARD_POST])
  File ".../pip_torch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 3583, in _assert_state
    raise ValueError(msg)
ValueError: expected to be in states [<TrainingState_.BACKWARD_PRE: 3>, <TrainingState_.BACKWARD_POST: 4>] but current state is TrainingState_.IDLE

I’m not sure I fully get your point. So basically, you wrap the big (parent) model with FSDP, run the forward pass submodule by submodule, and then run backward?

Is it possible that you wrap each net separately? And do the forward and backward separately?
Why do you need to have the current logic? Do you just need x1 and x2 for loss calculation?

Also cc: @agu for a second opinion here.

The two nets inside the parent model depend on each other: the output of the first net’s forward pass (x1) is the second net’s input. But I guess this is what you mean?

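class ParentModel(nn.Module):
    def __init__(self, first_net, second_net):
        super().__init__()
        # wrap each sub-network separately
        self.first_net = FSDP(first_net)
        self.second_net = FSDP(second_net)

    def forward(self, x):
        x1 = self.first_net.forward(x)
        x2 = self.second_net.forward(x1)
        # self.loss is the model's joint loss over both outputs
        total_loss = self.loss(x1, x2)
        return total_loss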

So I guess the general question is: does FSDP support nested models?

OK, now I understand why you are wrapping it that way. Before answering the question about FSDP, may I ask what happens within your loss? Do you just calculate the loss there, or do you also run backward inside the loss function? If so, how do you call backward() there?

Also, you don’t need to call .forward(). You can just do:

x1 = self.first_net(x)
x2 = self.second_net(x1)
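
As background for this advice: on a plain nn.Module, calling .forward() directly skips any hooks registered with register_forward_hook, whereas calling the module itself runs them. The sketch below shows this with plain PyTorch only (no FSDP); whether the difference matters for FSDP in 1.13 is not established in this thread, so treat it as an illustration rather than a diagnosis.

```python
import torch
import torch.nn as nn

calls = []
layer = nn.Linear(2, 4)
# Record every time the forward hook runs.
layer.register_forward_hook(lambda mod, inp, out: calls.append("forward_hook"))

x = torch.randn(3, 2)

layer.forward(x)                  # bypasses nn.Module.__call__, so no hook runs
calls_after_direct = list(calls)

layer(x)                          # goes through __call__, so the hook runs
calls_after_call = list(calls)

print(calls_after_direct, calls_after_call)
```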

We use a standard loss and backward, and we call backward after the entire forward pass, i.e.:

# get data x
total_loss = parent_model.forward(x)
total_loss.backward()

Hi, thanks for your reply. The code looks legit to me, and after chatting with @agu, we guess that there might be some limitations on the FSDP side. But from the information you gave us, it’s hard to infer what those limitations are.

If you can provide a minimal repro, it will help us prioritize fixing it; otherwise, it may take a while to get around to this.

Also, have you tried not calling .forward and doing what I mentioned above instead?

Thanks!

Thanks for helping! I’m trying to isolate the part of the model that is causing this. I found the module that FSDP always crashes at, but when I run it separately (even with the same input data), it succeeds in this standalone snippet:

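import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class Linear(nn.Module):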
    def __init__(self, n):
        super().__init__()
        self.linear = nn.Linear(n, n, bias=False)
        self.norm = nn.LayerNorm(n)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.linear(x)
        out = self.norm(out)
        out = self.relu(out)
        return out


class MyModel(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.layer = nn.Sequential(nn.Linear(2, n), nn.ReLU(inplace=True), Linear(n))

    def forward(self, x):
        return self.layer(x)


if __name__ == "__main__":
    init_dist() # helper function that sets all the distributed env variables
    m = MyModel(128).to("cuda")
    model = FSDP(m)
    model = model.train()
    x = torch.load("/tmp/input_saved_from_training.pt")
    out = model(x)
    loss = out.sum()
    loss.backward()

It’s likely something entangled in the full model that causes this. Do you have any tips on debugging, other than adding pieces to the snippet bit by bit, which is quite unproductive?

One thing I noticed: the Linear module here finishes its post-forward hook, which sets the state to IDLE, but never reaches the pre-backward hook and jumps directly to the post-backward hook, even though the pre-backward hooks are registered. What are the possible reasons that a module’s pre-backward hook isn’t triggered?

Thanks for your reply. That’s exactly why we asked for a minimal repro example. The example you gave here works with FSDP, and it’s different from what you described above (the joint loss, for example).
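
A repro closer to the described setup could start from a skeleton like the one below. It is only a sketch under assumptions: the layer sizes, the placeholder joint loss, and the `wrap` argument are hypothetical and not taken from the real model. By default `wrap` is the identity so the skeleton runs on CPU without a process group; passing FSDP (after initializing the process group on each rank) wraps each sub-network separately.

```python
import torch
import torch.nn as nn


class ParentModel(nn.Module):
    def __init__(self, n, wrap=lambda m: m):
        super().__init__()
        # `wrap` is a hypothetical injection point: identity by default,
        # or FSDP once torch.distributed has been initialized.
        self.first_net = wrap(nn.Sequential(nn.Linear(2, n), nn.ReLU()))
        self.second_net = wrap(nn.Sequential(nn.Linear(n, n), nn.ReLU()))

    def forward(self, x):
        x1 = self.first_net(x)
        x2 = self.second_net(x1)
        # Placeholder joint loss that depends on both outputs.
        return x1.pow(2).mean() + x2.pow(2).mean()


def run_step(model, x):
    # One forward pass followed by a single backward on the joint loss.
    loss = model(x)
    loss.backward()
    return loss


model = ParentModel(8)
loss = run_step(model, torch.randn(4, 2))
```

From here, pieces of the real forward pass can be moved in one at a time until the IDLE-state error appears.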

Also, the pre-backward hooks might not be triggered because our hook registration or its interaction with PyTorch autograd somehow went wrong.
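
To illustrate one way this ordering can arise in plain autograd (an illustration only, not a confirmed diagnosis of this model): a hook registered on a forward output tensor fires only if the gradient flows back through that exact tensor, while a hook on a parameter’s AccumulateGrad node fires whenever the parameter receives a gradient by any path. The two hooks below are stand-ins named after the FSDP hooks, not FSDP’s actual implementation, and the detach() is just one hypothetical way for the output tensor to drop out of the graph.

```python
import torch
import torch.nn as nn


def run_backward(detach_output):
    events = []
    layer = nn.Linear(2, 4, bias=False)
    x = torch.randn(3, 2)

    out = layer(x)
    # Stand-in for a pre-backward hook: a tensor hook on the forward output.
    out.register_hook(lambda grad: events.append("pre_backward"))

    # Stand-in for a post-backward hook: a hook on the parameter's
    # AccumulateGrad node, which runs once the parameter's gradient is ready.
    acc_grad = layer.weight.expand_as(layer.weight).grad_fn.next_functions[0][0]
    acc_grad.register_hook(lambda *unused: events.append("post_backward"))

    # If the hooked output is replaced by a detached copy, no gradient flows
    # through `out`, but the weight still gets a gradient via the second term.
    used = out.detach() if detach_output else out
    loss = used.sum() + layer.weight.sum()
    loss.backward()
    return events


normal_events = run_backward(detach_output=False)
detached_events = run_backward(detach_output=True)
print(normal_events, detached_events)
```

In the detached case only the post-backward stand-in runs, which mirrors the symptom of a module going straight from IDLE to its post-backward hook.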