You should be able to use the one I posted above, possibly with minor modifications; it works, at least in PyTorch 0.4.0.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn

class FullModel(nn.Module):
    """Wraps a model together with its loss, so the loss is computed
    inside each DataParallel replica instead of on a single GPU."""
    def __init__(self, model, loss):
        super(FullModel, self).__init__()
        self.model = model
        self.loss = loss

    def forward(self, targets, *inputs):
        outputs = self.model(*inputs)
        loss = self.loss(outputs, targets)
        # unsqueeze to shape (1,) so DataParallel can gather the per-GPU losses
        return torch.unsqueeze(loss, 0), outputs
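# When wrapped in DataParallel, FullModel is replicated on each GPU: targets
# and inputs are scattered along dim 0, each replica computes the loss on its
# own chunk, and the unsqueezed losses are gathered into a (num_gpus,) tensor.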
def DataParallel_withLoss(model, loss, **kwargs):
    model = FullModel(model, loss)
    device_ids = kwargs.get('device_ids', None)
    output_device = kwargs.get('output_device', None)
    model = torch.nn.DataParallel(model, device_ids=device_ids,
                                  output_device=output_device)
    if 'cuda' in kwargs:
        model = model.cuda(kwargs['cuda'])
    else:
        model = model.cuda()
    return model
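# A sketch of how the kwargs above can be used (assumes at least two visible
# GPUs; the names net/criterion and the device indices are illustrative):
#   model = DataParallel_withLoss(net, criterion,
#                                 device_ids=[0, 1], output_device=0, cuda=0)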
class toy(nn.Module):
    def __init__(self):
        super(toy, self).__init__()
        self.conv2d = torch.nn.Conv2d(1, 3, 1)

    def forward(self, x):
        return self.conv2d(x)

model = toy()
optimizer = torch.optim.SGD(model.parameters(), lr=1)
loss = torch.nn.L1Loss()
model = DataParallel_withLoss(model, loss)
gt = torch.rand(2, 3, 10, 10)
input = torch.rand(2, 1, 10, 10)
loss, _ = model(gt, input)
loss = loss.sum()  # one loss entry per GPU; reduce before backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
The second half of the script is a toy example: it builds a one-layer conv net, wraps it together with the loss, and runs a single training step.
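Note that the gathered loss has one entry per GPU, so the example reduces it with sum(). If you prefer a loss on the same scale as single-GPU training (gradients that don't grow with the number of devices), averaging works too; a minimal variant of the training step, reusing the model, optimizer, gt, and input defined above:

loss, _ = model(gt, input)
loss = loss.mean()  # average the per-GPU losses instead of summing them
optimizer.zero_grad()
loss.backward()
optimizer.step()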