FSDP2 and gradient w.r.t. inputs

I’m getting an error when trying to backpropagate through the norm of the gradients of my model’s output with respect to its inputs (i.e. a double backward) while the model is sharded with FSDP2’s fully_shard.

The second backward pass (loss.backward() in the example below) throws an error like

[rank0]: Traceback (most recent call last):
[rank0]:   File "....../test_autograd_fsdp/min_example.py", line 50, in <module>
[rank0]:     main()
[rank0]:   File "....../test_autograd_fsdp/min_example.py", line 46, in main
[rank0]:     loss.backward()
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 648, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: setStorage: sizes [10, 100], strides [1, 10], storage offset 0, and itemsize 4 requiring a storage size of 4000 are out of bounds for storage of size 0

How can I get around this? Below is a minimal example. Thanks!

import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

def main():
    # Expects a launcher such as torchrun to set LOCAL_RANK and the env:// rendezvous variables
    dist.init_process_group(backend='nccl', init_method='env://')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    model = SimpleNet().to(device)
    # Apply FSDP2: fully_shard each parameterized submodule, then the root module
    for name, module in model.named_modules():
        if module is not model and any(p.requires_grad for p in module.parameters(recurse=False)):
            fully_shard(module)
    fully_shard(model)

    x = torch.randn(32, 10, device=device, requires_grad=True)

    output = model(x)

    # First backward: gradient of the output w.r.t. the input, keeping the graph
    # so the norm of this gradient can be differentiated again
    grads = torch.autograd.grad(
        outputs=output.sum(),
        inputs=x,
        create_graph=True,
        retain_graph=True
    )[0]

    # Second backward: use the mean input-gradient norm as the loss and backprop
    # through it (this is where the error is raised)
    loss = torch.norm(grads, p=2, dim=-1).mean()
    loss.backward()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
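
For comparison, here is the same computation without fully_shard, as a minimal single-process sketch (it reuses the SimpleNet class from above; single_device_check is just an illustrative helper name). As far as I can tell, this plain double-backward pattern works fine outside of FSDP2, so the failure seems specific to the sharded case:

import torch

def single_device_check():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleNet().to(device)  # plain module, no fully_shard

    x = torch.randn(32, 10, device=device, requires_grad=True)
    output = model(x)

    # Gradient of the output w.r.t. the input, with create_graph=True so it can
    # be differentiated again
    grads = torch.autograd.grad(
        outputs=output.sum(),
        inputs=x,
        create_graph=True,
        retain_graph=True
    )[0]

    loss = torch.norm(grads, p=2, dim=-1).mean()
    loss.backward()  # no error in the unsharded case
    print("single-device double backward OK:", loss.item())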

cc @weifengpy on this FSDP2-related issue