DataParallel problem: how to distribute the model's output across different devices

Hi, I want to use DataParallel, but I have run into this problem. I use forward to get my output, and then I call fake_reply = fake_reply.to(device) (fake_reply is the output).
However, I find that the output is not distributed across devices 0 and 1 (I set CUDA_VISIBLE_DEVICES=0,1).

class D_FullModel(nn.Module):
    def __init__(self, model, gen, D_loss):
        super(D_FullModel, self).__init__()
        self.model = model
        self.gen = gen
        self.loss = D_loss

    def forward(self, targets, inputs):
        # soft labels for the fake and real samples
        fake_labels = torch.from_numpy(np.random.uniform(0, 0.3, size=(BATCH_SIZE))).float().to(device)
        real_labels = torch.from_numpy(np.random.uniform(0.7, 1.2, size=(BATCH_SIZE))).float().to(device)

        # sample a fake reply from the generator and pad it
        fake_reply, _, _ = self.gen.sample(inputs, targets)
        fake_reply = fill_with_padding(fake_reply, EOU, PAD).detach()
        fake_reply = fake_reply.to(device)
        print(fake_reply)

        # discriminator scores for the real and fake replies
        real_r = self.model.batchClassify(targets, inputs)
        fake_r = self.model.batchClassify(fake_reply, inputs)
        print(fake_r)

        x = torch.cat((fake_r, real_r), 0)
        y = torch.cat((fake_labels, real_labels), 0)
        loss = self.loss(x, y)
        return torch.unsqueeze(loss, 0), fake_reply


def D_DataParallel_withLoss(model, gen, D_loss):
    model = D_FullModel(model, gen, D_loss)
    model = torch.nn.DataParallel(model).to(device)
    return model

I found that after fake_reply = fake_reply.to(device), fake_reply sits entirely on device 0, while inputs is distributed across devices 0 and 1. As a result, fake_r = self.model.batchClassify(fake_reply, inputs) doesn't work.
How should I change the code to make the loss function work?
My code imitated the solution posted in https://discuss.pytorch.org/t/dataparallel-imbalanced-memory-usage/22551/21:

import torch
import torch.nn as nn

class FullModel(nn.Module):
    def __init__(self, model, loss):
        super(FullModel, self).__init__()
        self.model = model
        self.loss = loss

    def forward(self, targets, *inputs):
        outputs = self.model(*inputs)
        loss = self.loss(outputs, targets)
        return torch.unsqueeze(loss, 0), outputs


def DataParallel_withLoss(model, loss, **kwargs):
    model = FullModel(model, loss)
    if 'device_ids' in kwargs.keys():
        device_ids = kwargs['device_ids']
    else:
        device_ids = None
    if 'output_device' in kwargs.keys():
        output_device = kwargs['output_device']
    else:
        output_device = None
    if 'cuda' in kwargs.keys():
        cudaID = kwargs['cuda']
        model = torch.nn.DataParallel(model, device_ids=device_ids, output_device=output_device).cuda(cudaID)
    else:
        model = torch.nn.DataParallel(model, device_ids=device_ids, output_device=output_device).cuda()
    return model


class toy(nn.Module):
    def __init__(self):
        super(toy, self).__init__()
        self.conv2d = torch.nn.Conv2d(1, 3, 1)

    def forward(self, x):
        return self.conv2d(x)


model = toy()
optimizer = torch.optim.SGD(model.parameters(), lr=1)
loss = torch.nn.L1Loss()
model = DataParallel_withLoss(model, loss)
gt = torch.rand(2, 3, 10, 10)
input = torch.rand(2, 1, 10, 10)
loss, _ = model(gt, input)
loss = loss.sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

Since you call Tensor.to in the body of the forward function, the result will always live on that fixed device. DataParallel does the following for you: it replicates your model across N devices, splits your input tensor(s) into N chunks (across dimension 0 by default), runs forward on each of the replicas, and gathers (concatenates) the results on the output device. Each model replica's forward call is already pinned to a device, so you shouldn't explicitly move intermediate tensors to a fixed global device with to(device).
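As a minimal self-contained sketch of that idea (not your exact model; ToyFullModel, base, and noise are just illustrative names, and it assumes at least two visible GPUs): derive the device from a tensor that DataParallel has already scattered, rather than from a global device variable.

import torch
import torch.nn as nn

class ToyFullModel(nn.Module):
    def __init__(self, model, loss_fn):
        super(ToyFullModel, self).__init__()
        self.model = model
        self.loss = loss_fn

    def forward(self, targets, inputs):
        # `inputs` has already been scattered by DataParallel, so its device
        # is this replica's GPU; create any new tensors on that device instead
        # of on a fixed global `device`.
        local_device = inputs.device
        # stand-in for intermediates such as labels or fake_reply
        noise = torch.rand(inputs.size(0), 1, device=local_device)
        outputs = self.model(inputs) + noise
        loss = self.loss(outputs, targets)
        return torch.unsqueeze(loss, 0), outputs

base = nn.Linear(4, 4)
wrapped = torch.nn.DataParallel(ToyFullModel(base, nn.L1Loss())).cuda()
x = torch.rand(8, 4)   # scattered across the GPUs along dim 0
y = torch.rand(8, 4)
loss, _ = wrapped(y, x)
loss.sum().backward()

Applied to your D_FullModel, that would roughly mean creating fake_labels/real_labels with .to(inputs.device), sizing them from targets.size(0) rather than the global BATCH_SIZE (each replica only sees its slice of the batch), and replacing fake_reply.to(device) with fake_reply.to(inputs.device) if fill_with_padding doesn't already keep the tensor on the generator's device.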

Thank you very much! I will try it and see whether it works.