Hi, I want to use DataParallel, but I have run into a problem. I use forward to get my output, then I call fake_reply = fake_reply.to(device) (fake_reply is the output).
I find that the output is not distributed across devices 0 and 1 (I set CUDA_VISIBLE_DEVICES=0,1).
class D_FullModel(nn.Module):
    def __init__(self, model, gen, D_loss):
        super(D_FullModel, self).__init__()
        self.model = model
        self.gen = gen
        self.loss = D_loss

    def forward(self, targets, inputs):
        # loss_real = loss(real_r, real_labels)
        fake_labels = torch.from_numpy(np.random.uniform(0, 0.3, size=(BATCH_SIZE))).float().to(device)
        real_labels = torch.from_numpy(np.random.uniform(0.7, 1.2, size=(BATCH_SIZE))).float().to(device)
        fake_reply, _, _ = self.gen.sample(inputs, targets)
        fake_reply = fill_with_padding(fake_reply, EOU, PAD).detach()
        fake_reply = fake_reply.to(device)
        print(fake_reply)
        real_r = self.model.batchClassify(targets, inputs)
        fake_r = self.model.batchClassify(fake_reply, inputs)
        print(fake_r)
        x = torch.cat((fake_r, real_r), 0)
        y = torch.cat((fake_labels, real_labels), 0)
        loss = self.loss(x, y)
        return torch.unsqueeze(loss, 0), fake_reply

def D_DataParallel_withLoss(model, gen, D_loss):
    model = D_FullModel(model, gen, D_loss)
    model = torch.nn.DataParallel(model).to(device)
    return model
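For reference, here is a minimal probe (a toy of my own, assuming two visible GPUs; Probe is a made-up module, not from my real code) that shows what I think is happening: DataParallel scatters the input batch along dim 0, so each replica's forward sees its own slice, but a tensor moved with the module-level device always lands on cuda:0.

import torch
import torch.nn as nn

device = torch.device("cuda:0")  # module-level device, as in my code above

class Probe(nn.Module):
    def forward(self, x):
        # x is this replica's slice of the batch: cuda:0 on one GPU, cuda:1 on the other
        extra = torch.zeros(x.size(0)).to(device)  # always ends up on cuda:0
        print("input on", x.device, "| extra on", extra.device)
        return x

model = nn.DataParallel(Probe()).to(device)
model(torch.rand(4, 2))  # with 2 GPUs this prints (cuda:0, cuda:0) and (cuda:1, cuda:0)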
I found that after fake_reply = fake_reply.to(device), fake_reply sits entirely on device 0, while inputs is distributed across devices 0 and 1. Then fake_r = self.model.batchClassify(fake_reply, inputs) fails.
How should I change the code to make the loss function work?
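My guess is that, inside each replica's forward, I should move fake_reply (and the label tensors) to the device of the scattered inputs instead of the module-level device, something like the snippet below, though I am not sure this is the right fix:

# inside D_FullModel.forward, instead of .to(device):
fake_reply = fill_with_padding(fake_reply, EOU, PAD).detach()
fake_reply = fake_reply.to(inputs.device)  # follow this replica's device (cuda:0 or cuda:1)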
My code imitates the solution posted at https://discuss.pytorch.org/t/dataparallel-imbalanced-memory-usage/22551/21:
import torch
import torch.nn as nn

class FullModel(nn.Module):
    def __init__(self, model, loss):
        super(FullModel, self).__init__()
        self.model = model
        self.loss = loss

    def forward(self, targets, *inputs):
        outputs = self.model(*inputs)
        loss = self.loss(outputs, targets)
        return torch.unsqueeze(loss, 0), outputs

def DataParallel_withLoss(model, loss, **kwargs):
    model = FullModel(model, loss)
    if 'device_ids' in kwargs.keys():
        device_ids = kwargs['device_ids']
    else:
        device_ids = None
    if 'output_device' in kwargs.keys():
        output_device = kwargs['output_device']
    else:
        output_device = None
    if 'cuda' in kwargs.keys():
        cudaID = kwargs['cuda']
        model = torch.nn.DataParallel(model, device_ids=device_ids, output_device=output_device).cuda(cudaID)
    else:
        model = torch.nn.DataParallel(model, device_ids=device_ids, output_device=output_device).cuda()
    return model

class toy(nn.Module):
    def __init__(self):
        super(toy, self).__init__()
        self.conv2d = torch.nn.Conv2d(1, 3, 1)

    def forward(self, x):
        return self.conv2d(x)

model = toy()
optimizer = torch.optim.SGD(model.parameters(), lr=1)
loss = torch.nn.L1Loss()
model = DataParallel_withLoss(model, loss)
gt = torch.rand(2, 3, 10, 10)
input = torch.rand(2, 1, 10, 10)
loss, _ = model(gt, input)
loss = loss.sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
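One difference I notice: in this toy example both gt and input go through DataParallel's scatter, so every tensor inside forward lives on the replica's own device, whereas in my D_FullModel the tensors fake_labels, real_labels, and fake_reply are pinned to the module-level device. I suspect that is the mismatch, but I would appreciate confirmation.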