Sorry, I kept a bit of extra code because I started from https://github.com/pytorch/examples/blob/master/mnist/main.py and thought it would make comparing against that example easy. I have now reduced the code below to only what is needed to reproduce this problem. My network has one layer; 'main' just initializes everything and runs backprop on one random input. The only unique parts are my saveAverageD function and two custom modules. The custom modules and saveAverageD were built following your and @ptrblck's recommendations across several earlier posts, and they work perfectly on one GPU. But as you can see, they do not carry over to DataParallel with two GPUs.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
def saveAverageD(grad_out, Values):
    """Fold a gradient into ``Values.average`` as an exponential moving average.

    Meant to be used as a tensor backward hook. It prints the incoming
    gradient, then updates::

        average <- 0.99 * average + 0.01 * grad_out.sum((0, 2, 3))

    Args:
        grad_out: gradient tensor delivered to the hook. It is summed over
            dims 0, 2 and 3, so it must be at least 4-D. This presumably means
            (N, C, H, W), giving one sum per channel along dim 1 (TODO confirm).
        Values: object exposing an ``average`` tensor whose shape matches
            ``grad_out.sum((0, 2, 3))``.

    Returns:
        None. A backward hook that returns None leaves the gradient unchanged.
    """
    print('calling with:')
    print(grad_out)
    with torch.no_grad():
        channel_sum = grad_out.sum((0, 2, 3))
        # Write IN PLACE rather than rebinding ``Values.average``.
        # nn.DataParallel re-creates replicas of the module on every forward
        # pass, and this hook runs on those replicas. Rebinding the attribute
        # on a replica is thrown away with the replica. The device[0]
        # replica's buffers share storage with the base module, so an
        # in-place write there is visible through ``model.module``.
        # NOTE(review): writes made by replicas on other devices are still
        # discarded, so only device[0]'s shard of the batch contributes under
        # DataParallel.
        Values.average.copy_(Values.average * 0.99 + channel_sum * 0.01)
class valueTracker(nn.Module):
    """Module that only holds a float64 ``average`` buffer of length ``out_channels``.

    The tensor starts as all zeros on the module-level ``device``.
    """

    def __init__(self, out_channels):
        super().__init__()
        initial = torch.zeros(out_channels, dtype=torch.double, device=device)
        # A buffer rather than a Parameter: it travels with .to() and
        # state_dict(), but is not returned by parameters(), so the optimizer
        # never touches it.
        self.register_buffer('average', initial)
class averageSaveConv(nn.Module):
    """Wrap a layer and track a running average of the gradient of its output.

    Args:
        startLayer: module to wrap. It is cast to float64 via ``.double()``
            and stored as ``self.layer``.
        out_channels: length of the tracked average (a ``valueTracker``
            stored as ``self.values``). It is expected to equal the size of
            the wrapped layer's output along dim 1.
    """

    def __init__(self, startLayer, out_channels):
        super().__init__()
        self.values = valueTracker(out_channels)
        # nn.Module.double() casts the parameters and returns the module itself.
        self.layer = startLayer.double()

    def forward(self, x):
        out = self.layer(x)
        # Tensor.register_hook raises RuntimeError on a tensor that does not
        # require grad, e.g. when the model is run under torch.no_grad() for
        # evaluation. There is no gradient to record then, so skip the hook.
        if out.requires_grad:
            # self.values is resolved when the hook fires during backward.
            out.register_hook(lambda grad: saveAverageD(grad, self.values))
        return out
class Net(nn.Module):
    """Single-layer classifier: one gradient-tracking conv followed by log-softmax."""

    def __init__(self):
        super().__init__()
        # Conv2d(in=1, out=10, kernel=28, stride=1). On a 28x28 input this
        # produces a 1x1 spatial map with 10 channels.
        self.conv = averageSaveConv(nn.Conv2d(1, 10, 28, 1), 10)

    def forward(self, x):
        # Collapse every dim after the batch dim, then normalize over dim 1.
        logits = self.conv(x).flatten(start_dim=1)
        return F.log_softmax(logits, dim=1)
if __name__ == '__main__':
    # Minimal reproduction: build the one-layer net, wrap it in DataParallel
    # over every visible GPU, and take a single SGD step on random input.
    model = torch.nn.DataParallel(
        Net(), device_ids=range(torch.cuda.device_count())
    ).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    optimizer.zero_grad()

    # Two random single-channel 28x28 float64 inputs, both labelled class 0.
    # rand is drawn on the CPU and then moved, as before.
    data = torch.rand((2, 1, 28, 28), dtype=torch.float64).to(device)
    target = torch.zeros(2).to(device).long()

    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()

    # model.module is the wrapped base Net; report its tracked buffer.
    tracked = model.module.conv.values.average
    print('Conv Average Gradient:')
    print(tracked.tolist())