Can't get simple FSDP file to work

I'm trying to use fairscale's FSDP.

I can't get this file to run with torchrun:

torchrun --nproc_per_node 2 Aug_27_v2.py

It keeps failing with:

warnings.warn(
Traceback (most recent call last):
  File "Aug_27_v2.py", line 57, in <module>
    train(rank, world_size)
  File "Aug_27_v2.py", line 43, in train
    outputs = fsdp_model(inputs)
  File "/home/clark/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/clark/.local/lib/python3.8/site-packages/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1440, in forward
    outputs = self.module(*args, **kwargs)
  File "/home/clark/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/clark/.local/lib/python3.8/site-packages/fairscale/nn/misc/flatten_params_wrapper.py", line 487, in forward
    return self.module(*inputs, **kwinputs)
  File "/home/clark/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "Aug_27_v2.py", line 16, in forward
    return self.fc(x)
  File "/home/clark/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/clark/.local/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_addmm)

File:

# Aug_27_v2.py

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)

def train(rank, world_size):
    # Pick this rank's GPU (the rank index comes from the torchrun environment)
    device = torch.device(f"cuda:{rank}")

    # Initialize the distributed environment
    dist.init_process_group(
        'nccl',
        init_method='env://',
        rank=rank,
        world_size=world_size
    )

    # Create model and move to appropriate device
    model = SimpleModel().to(device)
    fsdp_model = FSDP(model)

    # Ensure criterion is also on the correct device
    criterion = nn.MSELoss().to(device)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)

    inputs = torch.randn(20, 10).to(device)
    targets = torch.randn(20, 10).to(device)

    for epoch in range(1000):  # loop over the dataset multiple times
        optimizer.zero_grad()
        outputs = fsdp_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        if rank == 0 and epoch % 100 == 0:  # Only print from one process to avoid clutter
            print('Epoch {}: Loss: {:.4f}'.format(epoch, loss.item()))

    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    rank = int(os.environ.get('RANK', 0))

    train(rank, world_size)

I couldn't find a version that worked anywhere, so I'm posting one below in case anyone else is looking. As far as I can tell, the problem is that each process's "current" CUDA device defaults to cuda:0 unless you set it explicitly, so FSDP materializes the gathered weights on cuda:0 on every rank while rank 1's input lives on cuda:1, which produces exactly this mat1/weight device mismatch.
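Here is a minimal sketch of the fix. The one substantive change from the file above is calling torch.cuda.set_device(rank) before the process group is created and the model is wrapped. It assumes a single node with one GPU per rank, so torchrun's RANK equals the local GPU index; the file name Aug_27_v3.py is just a placeholder.

# Aug_27_v3.py -- sketch of a working version; the key line is torch.cuda.set_device(rank)

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)

def train(rank, world_size):
    # Bind this process to its own GPU *before* any NCCL or FSDP work,
    # so tensors allocated on the "current" device land on the right GPU.
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    dist.init_process_group('nccl', init_method='env://',
                            rank=rank, world_size=world_size)

    model = SimpleModel().to(device)
    fsdp_model = FSDP(model)

    criterion = nn.MSELoss()
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)

    inputs = torch.randn(20, 10, device=device)
    targets = torch.randn(20, 10, device=device)

    for epoch in range(1000):
        optimizer.zero_grad()
        outputs = fsdp_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        if rank == 0 and epoch % 100 == 0:  # print from one rank only
            print(f'Epoch {epoch}: Loss: {loss.item():.4f}')

    dist.destroy_process_group()

if __name__ == '__main__':
    # torchrun sets WORLD_SIZE and RANK for each process it spawns
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    rank = int(os.environ.get('RANK', 0))
    train(rank, world_size)

Note that on a multi-node setup you'd pass LOCAL_RANK to set_device instead of the global RANK; for a single machine with --nproc_per_node 2 they're the same.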