I’m getting an error when I try to backpropagate through the norm of my model’s gradients w.r.t. its inputs (a double-backward, gradient-penalty-style computation) after sharding the model with fully_shard (FSDP2).
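For reference, on a plain, unsharded module the pattern I’m after is just the usual double backward. This is only an illustrative sketch (the tiny net and shapes here are placeholders, not my real model; the actual failing script is at the bottom):

import torch
import torch.nn as nn

# Plain, single-device version of the same computation, no fully_shard.
net = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 1))
x = torch.randn(32, 10, requires_grad=True)

out = net(x)
# First backward: gradients of the summed output w.r.t. the inputs,
# kept in the graph (create_graph=True) so they can be differentiated again.
(grads,) = torch.autograd.grad(out.sum(), x, create_graph=True)

# Second backward: backprop through the norm of those input gradients.
loss = grads.norm(p=2, dim=-1).mean()
loss.backward()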
Once the model is wrapped with fully_shard, the final loss.backward() throws errors such as:
[rank0]: Traceback (most recent call last):
[rank0]: File "....../test_autograd_fsdp/min_example.py", line 50, in <module>
[rank0]: main()
[rank0]: File "....../test_autograd_fsdp/min_example.py", line 46, in main
[rank0]: loss.backward()
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 648, in backward
[rank0]: torch.autograd.backward(
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]: _engine_run_backward(
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
[rank0]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: setStorage: sizes [10, 100], strides [1, 10], storage offset 0, and itemsize 4 requiring a storage size of 4000 are out of bounds for storage of size 0
How can I get around this? Below is a minimal example. Thanks!
import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


def main():
    dist.init_process_group(backend='nccl', init_method='env://')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    model = SimpleNet().to(device)

    # Shard every parameter-owning submodule, then the root module (FSDP2).
    for name, module in model.named_modules():
        if module is not model and any(p.requires_grad for p in module.parameters(recurse=False)):
            fully_shard(module)
    fully_shard(model)

    x = torch.randn(32, 10, device=device, requires_grad=True)
    output = model(x)

    # First backward: input gradients, kept in the graph so they can be differentiated again.
    grads = torch.autograd.grad(
        outputs=output.sum(),
        inputs=x,
        create_graph=True,
        retain_graph=True
    )[0]

    # Second backward: differentiate the norm of the input gradients. This is the call that fails.
    loss = torch.norm(grads, p=2, dim=-1).mean()
    loss.backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
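(For completeness: the script is meant to be launched with torchrun, since it reads LOCAL_RANK and uses the env:// init method, e.g. torchrun --nproc_per_node=2 min_example.py on a multi-GPU node.)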