How to train a CycleGAN with FSDP

# I want to build a CycleGAN with FSDP, but I get the error below and have no idea how to make FSDP run the loss backward on the same device.
self.genX2Y.train()
self.genY2X.train()
self.discX.train()
self.discY.train()
fake_x = self.genY2X(y_img1)
fake_y = self.genX2Y(x_img1)
# train discriminators
self.opt_disc.zero_grad()
D_X_real = self.discX(x_img1)
D_X_fake = self.discX(fake_x.detach())
D_Y_real = self.discY(y_img1)
D_Y_fake = self.discY(fake_y.detach())
# WGAN-style critic loss: raise scores on real images, lower them on (detached) fakes
D_X_loss = torch.mean(D_X_fake) - torch.mean(D_X_real)
D_Y_loss = torch.mean(D_Y_fake) - torch.mean(D_Y_real)
D_loss = (D_X_loss + D_Y_loss) / 2

dist.all_reduce(D_loss, op=dist.ReduceOp.SUM)
D_loss.backward(retain_graph=True)

self.opt_gen.zero_grad()

# train generators: adversarial loss

D_X_fake1 = self.discX(fake_x)
D_Y_fake1 = self.discY(fake_y)
loss_G_X1 = -torch.mean(D_X_fake1)
loss_G_Y1 = -torch.mean(D_Y_fake1)

# cycle-consistency loss

cycle_Y = self.genX2Y(fake_x)
cycle_X = self.genY2X(fake_y)
cycle_Y_loss1 = self.mse(cycle_Y, y_img1)
cycle_X_loss1 = self.mse(cycle_X, x_img1)

# identity loss

id_Y = self.genX2Y(y_img1)
id_X = self.genY2X(x_img1)
# each generator should leave an image from its own output domain unchanged
id_Y_loss = self.mse(id_Y, y_img1)
id_X_loss = self.mse(id_X, x_img1)

# combined generator loss

LAMBDA_CYCLE = 3
LAMBDA_ID = 1
cycle = (cycle_Y_loss1 + cycle_X_loss1) * LAMBDA_CYCLE
loss_GAN = (loss_G_Y1 + loss_G_X1) / 2
loss_identity = (id_Y_loss + id_X_loss) * LAMBDA_ID / 2
G_loss = loss_GAN + cycle + loss_identity
dist.all_reduce(G_loss, op=dist.ReduceOp.SUM)
G_loss.backward(retain_graph=True)
self.opt_gen.step()
self.opt_disc.step()

# error message:
File "/root/notebooks/groups/FSDP_torchrun_0420.py", line 969, in <module>
    main(args.save_every, args.total_epochs, args.batch_size)
File "/root/notebooks/groups/FSDP_torchrun_0420.py", line 958, in main
    trainer.train(total_epochs)
File "/root/notebooks/groups/FSDP_torchrun_0420.py", line 895, in train
    self._run_epoch(epoch)
File "/root/notebooks/groups/FSDP_torchrun_0420.py", line 772, in _run_epoch
    fake_x = self.genY2X(y_img1)
File "/root/notebooks/anaconda3/envs/aienv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
File "/root/notebooks/anaconda3/envs/aienv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 783, in forward
    args, kwargs = _pre_forward(
File "/root/notebooks/anaconda3/envs/aienv/lib/python3.11/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 414, in _pre_forward
    unshard_fn()
File "/root/notebooks/anaconda3/envs/aienv/lib/python3.11/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 439, in _pre_forward_unshard
    _unshard(
File "/root/notebooks/anaconda3/envs/aienv/lib/python3.11/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 331, in _unshard
    handle.post_unshard()
File "/root/notebooks/anaconda3/envs/aienv/lib/python3.11/site-packages/torch/distributed/fsdp/flat_param.py", line 1249, in post_unshard
    self._check_on_compute_device(self.flat_param)
File "/root/notebooks/anaconda3/envs/aienv/lib/python3.11/site-packages/torch/distributed/fsdp/flat_param.py", line 2259, in _check_on_compute_device
    _p_assert(
File "/root/notebooks/anaconda3/envs/aienv/lib/python3.11/site-packages/torch/distributed/utils.py", line 104, in _p_assert
    traceback.print_stack()
Expects tensor to be on the compute device cuda:1

Hard to answer for sure if you can't share a fully working repo, but I suspect this is the meaningful error:

Expects tensor to be on the compute device cuda:1

Can you make sure all your tensors are on the correct device?
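
For what it's worth, the usual cause with torchrun is that the process never pins its GPU, so torch.device("cuda") or .cuda() silently means cuda:0 on every rank, while rank 1's FSDP compute device is cuda:1. Here is a minimal sketch of the per-rank setup; build_generator, build_discriminator, and loader are hypothetical stand-ins for your own code, not taken from your post:

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun for each process
torch.cuda.set_device(local_rank)           # pin this process to its own GPU

# Wrap each sub-model with an explicit device_id so FSDP's compute device
# matches this rank (cuda:0 on rank 0, cuda:1 on rank 1, ...).
genX2Y = FSDP(build_generator(), device_id=local_rank)     # hypothetical ctor
genY2X = FSDP(build_generator(), device_id=local_rank)
discX = FSDP(build_discriminator(), device_id=local_rank)  # hypothetical ctor
discY = FSDP(build_discriminator(), device_id=local_rank)

# Every input batch has to land on the same device before the forward pass:
for x_img1, y_img1 in loader:               # hypothetical dataloader
    x_img1 = x_img1.to(local_rank, non_blocking=True)
    y_img1 = y_img1.to(local_rank, non_blocking=True)
    ...

If it still fails, print next(self.genY2X.parameters()).device and y_img1.device on each rank right before the failing line; whichever one still reports cuda:0 on rank 1 was created or moved before the device was pinned.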