Multi GPU hook not correctly filling buffer

Sorry, I had kept some extra code because I started from https://github.com/pytorch/examples/blob/master/mnist/main.py and thought it would be easy to compare against it. I have now reduced the code below to only the exact things needed to reproduce the problem. The network is a single layer, and the ‘main’ just initializes everything and runs backprop on one random input. The only unusual parts are my saveAverageD function and the two custom modules. The custom modules and the saveAverageD function follow your and @ptrblck 's recommendations from a number of previous posts, and they work perfectly on 1 GPU (a single-GPU reference run is included after the code below for comparison). But as you can see, this does not carry over to DataParallel and 2 GPUs.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Backward hook: keep a running exponential average of the gradient,
# summed over the batch and spatial dimensions, in the tracker's buffer.
def saveAverageD(grad_out, Values):
    print('calling with:')
    print(grad_out)
    with torch.no_grad():
        Values.average = Values.average * 0.99 + grad_out.sum((0, 2, 3)) * 0.01
        
# Holds the running average as a registered buffer so it is part of the module state.
class valueTracker(nn.Module):
    def __init__(self, out_channels):
        super(valueTracker, self).__init__()
        self.register_buffer('average', torch.zeros(out_channels, device=device, dtype=torch.double))

# Wraps a layer and registers a hook on its output so that, during backward,
# the gradient average is accumulated into self.values.
class averageSaveConv(nn.Module):
    def __init__(self, startLayer, out_channels):
        super(averageSaveConv, self).__init__()
        self.values = valueTracker(out_channels)
        self.layer = startLayer.double()

    def forward(self, x):
        out = self.layer(x)
        out.register_hook(lambda grad: saveAverageD(grad, self.values))
        return out

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = averageSaveConv(nn.Conv2d(1, 10, 28, 1), 10)

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)
        output = F.log_softmax(x, dim=1)
        return output
    
if __name__ == '__main__':
    
    # set up the net, DataParallel, and optimizer in the most basic way
    model = Net()
    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    optimizer.zero_grad()
    data = torch.rand((2, 1, 28, 28), dtype=torch.float64).to(device)
    target = torch.zeros(2).to(device).long()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    print('Conv Average Gradient:')
    print(model.module.conv.values.average.tolist())
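For comparison, this is roughly the single-GPU run I mean when I say it works on 1 GPU: the same imports and class definitions as above, the same kind of random input, just without the DataParallel wrapper. There the hook updates the buffer as expected.

# Single-GPU reference (same Net and definitions as above, no DataParallel wrapper)
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
data = torch.rand((2, 1, 28, 28), dtype=torch.float64).to(device)
target = torch.zeros(2).to(device).long()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
print('Conv Average Gradient (single GPU):')
print(model.conv.values.average.tolist())  # buffer is filled here, as expected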