Multi GPU Hook not correctly doing filling buffer

Reposting of an old question I never got a reply to Multi GPU Hook not correctly filling buffer. I know runnable minimal examples are the way to go so I now I have a full runnable minimal example of the code written. Please take a look and help me get to the bottom of this. This is what it looks like with two runs:

> CUDA_VISIBLE_DEVICES=1 python mainMinimalExampleMultiGPU.py 
Conv Average Gradient:
[0.00044749726757895453, 0.0014000369330415242, -0.0008686411516918384]
fc Average Gradient:
[-0.004141018711068057, 0.0015833583892040112, 0.0011787552821185693, 0.0010372935249398085, -0.004048425233274684, 0.0006052607123126522, 0.0013055124756185216, 0.0007034393619838467, 0.0007521140892023609, 0.0010237101089629694]
> CUDA_VISIBLE_DEVICES=0,1 python mainMinimalExampleMultiGPU.py 
Conv Average Gradient:
[0.0, 0.0, 0.0]
fc Average Gradient:
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

The following is the entire program which was started with the mnist example script.

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

def saveAverageD(grad_out, Values):
    with torch.no_grad():
        if(len(grad_out.shape) == 2):
            Values[0].average = Values[0].average * 0.99 + grad_out.sum((0)) * 0.01
        else:
            Values[0].average = Values[0].average * 0.99 + grad_out.sum((0,2,3)) * 0.01
        
class valueTracker(nn.Module):
    def __init__(self, out_channels):
        super(valueTracker, self).__init__()
        self.register_buffer('average', torch.zeros(out_channels, device=device, dtype=torch.double))

class averageSaveConv(nn.Module):
    def __init__(self, startLayer, out_channels):
        super(averageSaveConv, self).__init__()
        self.values = nn.ModuleList([])
        self.values.append(valueTracker(out_channels))
        self.layer = startLayer.double()
    def forward(self,x):
        out = self.layer(x)
        out.register_hook(lambda grad: saveAverageD(grad, self.values))
        return out

        
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = averageSaveConv(nn.Conv2d(1, 3, 5, 1),3)                
        self.fc = averageSaveConv(nn.Linear(432, 10),10)

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device).double(), target.to(device).long()
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
        
        print('Conv Average Gradient:')
        print(model.module.conv.values[0].average.tolist())
        print('fc Average Gradient:')
        print(model.module.fc.values[0].average.tolist())

        exit()#just here for debugging
        
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device).double(), target.to(device).long()
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=2, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    args = parser.parse_args()

    torch.manual_seed(1)


    kwargs = {'batch_size': args.batch_size}
    if use_cuda:
        kwargs.update({'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True},
                     )

    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1,**kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    model = Net()
    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())).to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)


if __name__ == '__main__':
    main()

@ptrblck @albanD Any chance either of you can take a look at this? I got the minimal example up so hopefully it will be easy to replicate the problem.

Just bumping again since I keep posting this on weekends. Please @ptrblck @albanD, you’re my only hope.

Hi,

I did see the issue when you pinged the first time but I don’t think I have much to say about it.
I would advise that you try and reduce your code as much as possible. Given all the things that happen there, I have no idea what could be going wrong :confused:

Sorry, I kept a bit of extra code since I started with https://github.com/pytorch/examples/blob/master/mnist/main.py. and thought it would be easy to compare to that. I have further reduced the code below to only have the exact things to reproduce this problem. My network is 1 layer, the ‘main’ just inititalizes things and runs backprop on one random input. The only thing that is unique is my saveAverageD function and 2 custom modules. The custom modules and saveAverageD function were made by your and @ptrblck 's reccomendations over a number of previous posts and they work perfectly on 1 GPU. But as you can see it does not translate to DataParallel and 2 GPUs.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

def saveAverageD(grad_out, Values):
    print('calling with:')
    print(grad_out)
    with torch.no_grad():
        Values.average = Values.average * 0.99 + grad_out.sum((0,2,3)) * 0.01
        
class valueTracker(nn.Module):
    def __init__(self, out_channels):
        super(valueTracker, self).__init__()
        self.register_buffer('average', torch.zeros(out_channels, device=device, dtype=torch.double))

class averageSaveConv(nn.Module):
    def __init__(self, startLayer, out_channels):
        super(averageSaveConv, self).__init__()
        self.values = valueTracker(out_channels)
        self.layer = startLayer.double()
        
    def forward(self,x):
        out = self.layer(x)
        out.register_hook(lambda grad: saveAverageD(grad, self.values))
        return out

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = averageSaveConv(nn.Conv2d(1, 10, 28,1),10)                

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)
        output = F.log_softmax(x, dim=1)
        return output
    
if __name__ == '__main__':
    
    #setup the net and data parallel and optimizer in the most basic way
    model = Net()
    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    optimizer.zero_grad()
    data, target = torch.rand((2,1,28,28),dtype=torch.float64).to(device), torch.zeros(2).to(device).long()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    print('Conv Average Gradient:')
    print(model.module.conv.values.average.tolist())

Just so everything is on this page this is a summary of what should be happening:

saveAverageD is a function to be called in a backward hook to keep track of the average value in grad_out

valueTracker just has a single buffer to store this average

averageSaveConv: this keeps track of a single layer and a single instance of valueTracker to be used for average saving

I edited the code to have an extra print. The following is now the output. Again, this works exactly as you helped me get to with 1 GPU but fails with 2 GPUs. As you can see with 1 GPU it loads both random inputs together and computes the average and saves with saveAverageD called one time on the one GPU. on 2GPUs it calls saveAverageD twice, each GPU gets one of the inputs, but then the average buffer is not tracked.

$ CUDA_VISIBLE_DEVICES=1 python mainMinimalExampleMultiGPU.py 
calling with:
tensor([[[[-0.4504]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0442]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0555]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0622]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0781]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0507]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0317]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0323]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0349]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0607]]],                                                                                                                                                                                                                                                         
                                                                                                                                                                                                                                                                               
                                                                                                                                                                                                                                                                               
        [[[-0.4533]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0368]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0420]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0848]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0637]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0569]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0421]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0330]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0346]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0594]]]], device='cuda:0', dtype=torch.float64)                                                                                                                                                                                                                  
Conv Average Gradient:                                                                                                                                                                                                                                                         
[-0.009037267504024003, 0.0008102439568684652, 0.0009758195369250783, 0.001469522990143697, 0.001417696908934414, 0.0010763320891902935, 0.0007385695544331516, 0.0006529860781865106, 0.0006942532124915154, 0.0012018431768508766]                                           


$ CUDA_VISIBLE_DEVICES=0,1 python mainMinimalExampleMultiGPU.py                                                                                                                                               
calling with:                                                                                                                                                                                                                                                                  
calling with:                                                                                                                                                                                                                                                                  
tensor([[[[-0.4665]],                                                                                                                                                                                                                                                          
                                                                                                                                                                                                                                                                               
         [[ 0.0500]],

         [[ 0.0467]],

         [[ 0.0672]],

         [[ 0.0635]],

         [[ 0.0213]],

         [[ 0.0543]],

         [[ 0.0817]],

         [[ 0.0586]],

         [[ 0.0232]]]], device='cuda:1', dtype=torch.float64)
tensor([[[[-0.4679]],

         [[ 0.0354]],

         [[ 0.0464]],

         [[ 0.0746]],

         [[ 0.0570]],

         [[ 0.0386]],

         [[ 0.0673]],

         [[ 0.0740]],

         [[ 0.0512]],

         [[ 0.0235]]]], device='cuda:0', dtype=torch.float64)
Conv Average Gradient:
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

Is my new code sample reduced enough? I don’t know how I can reduce it any further.

Ho I didn’t saw that you updated the code in Multi GPU Hook not correctly doing filling buffer and not the original post.

This looks ok. I don’t have a multigpu machine to run it though :confused: @ptrblck would you have a minute to check that please?

1 Like

As mentioned in my previous post, I don’t think this will work out of the box without reducing the value somehow manually.
Hooks would be registered on each replica and thus would only be valid for this model only.
I don’t think that nn.DataParallel or DDP reduces hook by default, but might be wrong, so I still think that your best bet would be to use e.g. torch.nn.parallel.gather.

2 Likes

The only information I could find about parallel.gather is from this parallelism tutorial. I think you’re right that might be what I want to be using, But I don’t think I understand how to use it correctly. I’ve tried a few different ways to rework the code but when I call nn.parallel.gather(model.module.conv.values.average, 0) it still is always returning all 0’s with 2GPUs. Could you please provide any additional help? The backword hook is already in the forward call so it should be on both devices. I tried initializing the array after dataparallel as well. and I tried the dataprallel subclass from that tutorial fixed with the info here

I think I got it! This seems super inefficient though. Can one of you please confirm this is what you meant? Is this the right way to do this and its whats happening behind the scenes with dataparallel anyway? I am specifically concerned about everything in Net(). I do know mathematically I wouldn’t have to ‘split’ every iteration if all I was actually doing in the non-minimal example was averaging. And does this mean if I want to use dataparallel with these buffers I need to manually wrap every layer rather than just calling dataparallel on the network? Sorry, im just generally getting the feeling I did not do this right.

This is the new output:

$ CUDA_VISIBLE_DEVICES=0,1 python mainMinimalExampleMultiGPU.py 
calling with:                                                                                                                                                                                                                                                                   
calling with:                                                                                                                                                                                                                                                                   
[[[[-0.41571250258624853]], [[0.04142077242016727]], [[0.06950999356652776]], [[0.03586417237630698]], [[0.04916706555685479]], [[0.0479719964996224]], [[0.04534353418505202]], [[0.036737063334092955]], [[0.04897484751185222]], [[0.04072305713577202]]]]                   
[[[[-0.4097441719176553]], [[0.05437682468935652]], [[0.056508038808905765]], [[0.04807834249760404]], [[0.04681839171079582]], [[0.04986861282603519]], [[0.0284716049455237]], [[0.04319707706917308]], [[0.053327966629887764]], [[0.02909731274037328]]]]                   
Conv Average Gradient 0:                                                                                                                                                                                                                                                        
[-0.041571250258624855, 0.004142077242016727, 0.006950999356652776, 0.003586417237630698, 0.004916706555685479, 0.0047971996499622405, 0.004534353418505202, 0.0036737063334092955, 0.004897484751185222, 0.004072305713577202]                                                 
Conv Average Gradient 1:                                                                                                                                                                                                                                                        
[-0.04097441719176553, 0.005437682468935653, 0.005650803880890577, 0.004807834249760404, 0.004681839171079582, 0.004986861282603519, 0.00284716049455237, 0.0043197077069173076, 0.005332796662988777, 0.002909731274037328]                                                    
Conv Average Gradient:                                                                                                                                                                                                                                                          
[-0.04127283372519519, 0.00478987985547619, 0.0063009016187716765, 0.004197125743695551, 0.004799272863382531, 0.00489203046628288, 0.003690756956528786, 0.003996707020163302, 0.005115140707086999, 0.003491018493807265]                                                     
updated Conv Average Gradient 0:                                                                                                                                                                                                                                                
[-0.04127283372519519, 0.00478987985547619, 0.0063009016187716765, 0.004197125743695551, 0.004799272863382531, 0.00489203046628288, 0.003690756956528786, 0.003996707020163302, 0.005115140707086999, 0.003491018493807265]                                                     
updated Conv Average Gradient 1:                                                                                                                                                                                                                                                
[-0.04127283372519519, 0.00478987985547619, 0.0063009016187716765, 0.004197125743695551, 0.004799272863382531, 0.00489203046628288, 0.003690756956528786, 0.003996707020163302, 0.005115140707086999, 0.003491018493807265]                                                     

$ CUDA_VISIBLE_DEVICES=0 python mainMinimalExampleMultiGPU.py                                                                                                                                       
calling with:                                                                                                                                                                                                                                                                   
[[[[-0.41571250258624853]], [[0.04142077242016727]], [[0.06950999356652776]], [[0.03586417237630698]], [[0.04916706555685479]], [[0.0479719964996224]], [[0.04534353418505202]], [[0.036737063334092955]], [[0.04897484751185222]], [[0.04072305713577202]]], [[[-0.4097441719176553]], [[0.05437682468935652]], [[0.056508038808905765]], [[0.04807834249760404]], [[0.04681839171079582]], [[0.04986861282603519]], [[0.0284716049455237]], [[0.04319707706917308]], [[0.053327966629887764]], [[0.02909731274037328]]]]                                      
Conv Average Gradient 0:                                                                                                                                                                                                                                                        
[-0.04127283372519519, 0.00478987985547619, 0.0063009016187716765, 0.004197125743695552, 0.004799272863382531, 0.00489203046628288, 0.003690756956528786, 0.003996707020163302, 0.005115140707087, 0.0034910184938072653]                                                       
Conv Average Gradient:                                                                                                                                                                                                                                                          
[-0.04127283372519519, 0.00478987985547619, 0.0063009016187716765, 0.004197125743695552, 0.004799272863382531, 0.00489203046628288, 0.003690756956528786, 0.003996707020163302, 0.005115140707087, 0.0034910184938072653]                                                       
updated Conv Average Gradient 0:                                                                                                                                                                                                                                                
[-0.04127283372519519, 0.00478987985547619, 0.0063009016187716765, 0.004197125743695552, 0.004799272863382531, 0.00489203046628288, 0.003690756956528786, 0.003996707020163302, 0.005115140707087, 0.0034910184938072653]  ```

and this is the new code for the minimal example:

import torch                                                                                                                                                                                                                                                                    
import torch.nn as nn                                                                                                                                                                                                                                                           
import torch.nn.functional as F                                                                                                                                                                                                                                                 
import torch.optim as optim                                                                                                                                                                                                                                                     
from torchvision import datasets, transforms                                                                                                                                                                                                                                    
import random                                                                                                                                                                                                                                                                   
                                                                                                                                                                                                                                                                                
use_cuda = torch.cuda.is_available()                                                                                                                                                                                                                                            
device = torch.device("cuda" if use_cuda else "cpu")                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                
def saveAverageD(grad_out, Values):                                                                                                                                                                                                                                             
    print('calling with:')                                                                                                                                                                                                                                                      
    print(grad_out.tolist())                                                                                                                                                                                                                                                    
    with torch.no_grad():                                                                                                                                                                                                                                                       
        Values.average = Values.average * 0.9 + grad_out.mean((0,2,3)) * 0.1                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                
class valueTracker(nn.Module):                                                                                                                                                                                                                                                  
    def __init__(self, out_channels):                                                                                                                                                                                                                                           
        super(valueTracker, self).__init__()                                                                                                                                                                                                                                    
        self.register_buffer('average', torch.zeros(out_channels, device=device, dtype=torch.double))                                                                                                                                                                           
                                                                                                                                                                                                                                                                                
class averageSaveConv(nn.Module):                                                                                                                                                                                                                                               
    def __init__(self, startLayer, out_channels):                                                                                                                                                                                                                               
        super(averageSaveConv, self).__init__()                                                                                                                                                                                                                                 
        self.layer = startLayer.double()                                                                                                                                                                                                                                        
        self.out_channels = out_channels                                                                                                                                                                                                                                        
        self.values=valueTracker(self.out_channels)                                                                                                                                                                                                                             
                                                                                                                                                                                                                                                                                
    def forward(self,x):                                                                                                                                                                                                                                                        
        out = self.layer(x)                                                                                                                                                                                                                                                     
        out.register_hook(lambda grad: saveAverageD(grad, self.values))                                                                                                                                                                                                         
        return out                                                                                                                                                                                                                                                              
                                                                                                                                                                                                                                                                                
if(torch.cuda.device_count() == 1):                                                                                                                                                                                                                                             
    device_ids = [0]                                                                                                                                                                                                                                                            
else:                                                                                                                                                                                                                                                                           
    device_ids = [0,1]                                                                                                                                                                                                                                                          
class Net(nn.Module):                                                                                                                                                                                                                                                           
    def __init__(self):
        super(Net, self).__init__()
        self.conv = averageSaveConv(nn.Conv2d(1, 10, 28,1),10).to(device)
        self.replicas = nn.parallel.replicate(self.conv, device_ids)
    def forward(self, x):
        output_device = 0
        inputs = nn.parallel.scatter(x, device_ids)
        replicas = self.replicas[:len(inputs)]
        outputs = nn.parallel.parallel_apply(replicas, inputs)
        x = nn.parallel.gather(outputs, output_device)
        x = torch.flatten(x, 1)
        output = F.log_softmax(x, dim=1)
        return output
    def gather(self):
        self.conv.values.average = nn.parallel.gather([self.replicas[x].values.average for x in range(len(device_ids))], self.conv.values.average.device).reshape(len(device_ids),-1).mean(0)
    def split(self):
        for replica in self.replicas:
            replica.values.average.copy_(self.conv.values.average)

if __name__ == '__main__':
    random_seed = 4
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)    
    model = Net()
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    optimizer.zero_grad()
    data, target = torch.rand((2,1,28,28),dtype=torch.float64).to(device), torch.zeros(2).to(device).long()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    model.gather()
    optimizer.step()
    for id in device_ids:
        print('Conv Average Gradient %d:' % id)
        print(model.replicas[id].values.average.tolist())
    print('Conv Average Gradient:')
    print(model.conv.values.average.tolist())
    model.split()
    for id in device_ids:
        print('updated Conv Average Gradient %d:' % id)
        print(model.replicas[id].values.average.tolist())

@ptrblck thoughts? The things in Net and manually calling my gather and split functions doesn’t seem like the correct usage of these functions. And does this mean I can’t use DataParallel anymore?