I'm trying to use FSDP (the fairscale version). I can't get this file to run under torchrun:
torchrun --nproc_per_node 2 Aug_27_v2.py
It keeps failing with:
warnings.warn(
Traceback (most recent call last):
  File "Aug_27_v2.py", line 57, in <module>
    train(rank, world_size)
  File "Aug_27_v2.py", line 43, in train
    outputs = fsdp_model(inputs)
  File "/home/clark/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/clark/.local/lib/python3.8/site-packages/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1440, in forward
    outputs = self.module(*args, **kwargs)
  File "/home/clark/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/clark/.local/lib/python3.8/site-packages/fairscale/nn/misc/flatten_params_wrapper.py", line 487, in forward
    return self.module(*inputs, **kwinputs)
  File "/home/clark/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "Aug_27_v2.py", line 16, in forward
    return self.fc(x)
  File "/home/clark/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/clark/.local/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_addmm)
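For what it's worth, I can trigger the same class of error outside FSDP with a tiny snippet (needs two GPUs), so I understand the message itself; what I don't understand is where the second device comes from in my script, since I move everything to cuda:{rank}:

# Standalone sanity check (not my script): input on cuda:1, Linear weights on cuda:0.
import torch
import torch.nn as nn

layer = nn.Linear(10, 10).to("cuda:0")   # weight/bias live on GPU 0
x = torch.randn(4, 10, device="cuda:1")  # input lives on GPU 1
layer(x)  # RuntimeError: Expected all tensors to be on the same device ...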
The full file:
# Aug_27_v2.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)


def train(rank, world_size):
    # Each process drives the GPU matching its rank
    device = torch.device(f"cuda:{rank}")

    # Initialize the distributed environment
    dist.init_process_group(
        'nccl',
        init_method='env://',
        rank=rank,
        world_size=world_size
    )

    # Create the model and move it to this rank's device before wrapping in FSDP
    model = SimpleModel().to(device)
    fsdp_model = FSDP(model)

    # Loss and optimizer; MSELoss has no parameters, but keep it on the device anyway
    criterion = nn.MSELoss().to(device)
    optimizer = optim.SGD(fsdp_model.parameters(), lr=0.01)

    # One fixed random batch, reused every step
    inputs = torch.randn(20, 10).to(device)
    targets = torch.randn(20, 10).to(device)

    for epoch in range(1000):
        optimizer.zero_grad()
        outputs = fsdp_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        if rank == 0 and epoch % 100 == 0:  # print from one process only to avoid clutter
            print('Epoch {}: Loss: {:.4f}'.format(epoch, loss.item()))

    dist.destroy_process_group()


if __name__ == '__main__':
    # torchrun sets WORLD_SIZE and RANK in the environment
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    rank = int(os.environ.get('RANK', 0))
    train(rank, world_size)
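My current guess is that fairscale's FSDP does some of its parameter work on the *current* CUDA device, which defaults to cuda:0 in every process unless it's set explicitly, so on rank 1 the input (cuda:1) meets weights on cuda:0. This is the variant I'm about to try, pinning each process to its GPU via LOCAL_RANK (which torchrun sets per process) before anything touches CUDA. Untested sketch, reusing SimpleModel from above:

# fix_sketch.py -- untested guess: pin each process to its GPU before FSDP
import os
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

local_rank = int(os.environ.get('LOCAL_RANK', 0))  # set per process by torchrun
torch.cuda.set_device(local_rank)                  # current device := cuda:{local_rank}
device = torch.device('cuda', local_rank)

# With init_method='env://', rank and world_size are read from the env torchrun sets
dist.init_process_group('nccl', init_method='env://')

model = SimpleModel().to(device)  # SimpleModel as defined above
fsdp_model = FSDP(model)          # hoping everything now stays on cuda:{local_rank}

Does that sound right, or am I misreading where fairscale decides which device to use?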