How to collect all the gradients from multiple GPUs

Hi!

I am trying to run this code:


to visualize saliency maps from a ResNet. The code is written for the CPU, so I modified it a little to move the model and tensors to CUDA and run it on GPUs. The main result I care about is on line 65 (self.gradients).

The code works fine on a single GPU. However, when I run it on multiple GPUs with an input of size 64x3x32x32 (CIFAR-10 images), the result I get is 16x3x32x32 (it should be 64x3x32x32).

To me, the problem seems to be on line 35: the register_backward_hook function fails to collect the gradients from all of the GPUs and only keeps the ones from the last GPU.
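
For context, the hook registration in that code looks roughly like this (a simplified sketch from memory, not the exact code from the repo; in my version the model is wrapped in nn.DataParallel):

def hook_layers(self):
    def hook_function(module, grad_in, grad_out):
        # grad_in[0] is the gradient w.r.t. the input of the hooked layer
        self.gradients = grad_in[0]

    # register on the first layer of the (possibly DataParallel-wrapped) model
    base = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
    first_layer = list(base.children())[0]
    first_layer.register_backward_hook(hook_function)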

Am I doing something wrong, or is this a known PyTorch bug? If so, is there any workaround for this issue?

Thank you very much!

Here in your code you’re setting

def hook_function(module, grad_in, grad_out):
    self.gradients = grad_in[0]

I think this hook runs once on each GPU, so each replica overwrites self.gradients and in the end you only keep one-fourth of what you should have gotten (assuming 4 GPUs).

You can try defining self.gradients as a Python list and appending to it instead:

def hook_function(module, grad_in, grad_out):
    self.gradients.append(grad_in[0])
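
Something along these lines should collect the per-GPU chunks (just a rough sketch; SaliencyHook and collect are illustrative names, not from your code):

import torch

class SaliencyHook:
    """Gather first-layer input gradients across DataParallel replicas."""

    def __init__(self, first_layer):
        self.gradients = []
        first_layer.register_backward_hook(self._hook)

    def _hook(self, module, grad_in, grad_out):
        # with nn.DataParallel, each replica fires this hook with its chunk of the batch
        self.gradients.append(grad_in[0])

    def collect(self):
        # stitch the per-GPU chunks back into one batch-sized tensor
        return torch.cat([g.cpu() for g in self.gradients], dim=0)

One caveat: the chunks may not arrive in the original batch order, so if the order matters you may want to record grad_in[0].device alongside each chunk and sort before concatenating.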

Hi @richard, thanks a lot for your help!

Yes, I tested your method and it is working perfectly! Thanks a lot!

I have another question. When I use the code on images with a batch size of 128, GPU memory blows up and I get an out-of-memory error. Do you have any suggestions for that?

Thanks again!

Other than shrinking the batch size, not really, sorry. Maybe someone else can weigh in here about how to better work with OOMs.
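
One small thing that might be worth trying, though (just a sketch, and only if you don't need the collected gradients to stay on the GPU): detach each chunk and move it to the CPU inside the hook, so the saved tensors don't pile up in GPU memory:

def hook_function(module, grad_in, grad_out):
    # copy the chunk to host memory so it doesn't accumulate on the GPU
    self.gradients.append(grad_in[0].detach().cpu())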


Hi!

I’m trying to run this code:

    def forward(self, x):
        self.activations = []
        self.gradients = []
        self.grad_index = 0
        self.activation_to_layer = {}
        self.data = []

        activation_index = 0

        # unwrap DataParallel once so shared submodules (fc6, relu, dropout) are reachable directly
        base = self.model.module if isinstance(self.model, nn.DataParallel) else self.model

        for layer, module in self.model.named_modules():
            if ('conv' in layer) or ('pool' in layer) or ('fc' in layer):
                if 'fc6' in layer:
                    # flatten the pooled conv features before the first fully connected layer
                    x = x.view(-1, base.fc6.in_features)
                x = module(x)
            if isinstance(module, torch.nn.modules.conv.Conv3d):
                # the hook is registered on the output (activation) of this conv layer
                print(module)
                x.register_hook(self.compute_rank)
                self.activations.append(x)
                self.activation_to_layer[activation_index] = layer
                activation_index += 1
                x = base.relu(x)
            elif isinstance(module, torch.nn.modules.Linear) and layer != 'fc8':
                x = base.dropout(base.relu(x))

        return x

    def compute_rank(self, grad):
        """
        Compute the Taylor expansion without abs of each channel
        return:
            self.activations: feature map before relu in each layer
            self.filter_ranks: Taylor value without abs over spatial and batch
        """
        activation_index = len(self.activations) - self.grad_index - 1
        activation = self.activations[activation_index]
        if self.pruning_level == 'channel':
            self.data.append(grad)
            print(len(self.data))

self.model is a DataParallel object. However, the code only runs on GPU 0. Am I doing something wrong?
The output of this code:

Conv3d(3, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
1
2
3
4
5
6
7
8

and the nvidia-smi output suggests that it only runs on the first GPU:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.130                Driver Version: 384.130                   |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 1080    Off  | 00000000:05:00.0 Off |                  N/A |
| 43%   71C    P2   194W / 180W |   7353MiB /  8112MiB |     98%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX 1080    Off  | 00000000:06:00.0 Off |                  N/A |
| 34%   61C    P0    37W / 180W |     10MiB /  8114MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  GeForce GTX 1080    Off  | 00000000:09:00.0 Off |                  N/A |
| 27%   57C    P0    38W / 180W |     10MiB /  8114MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   3  GeForce GTX 1080    Off  | 00000000:0A:00.0 Off |                  N/A |
| 24%   49C    P0    36W / 180W |     10MiB /  8114MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     16765      C   python                                      7343MiB |
+-----------------------------------------------------------------------------+
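
In case it matters, this is roughly how I build and call the model (a simplified sketch; MyC3D and FilterPruner are placeholder names for my actual classes, and the real code loads pretrained weights):

import torch
import torch.nn as nn

model = nn.DataParallel(MyC3D()).cuda()   # MyC3D: placeholder for my C3D-style net with conv*/pool*/fc6/fc8 layers

pruner = FilterPruner(model)              # placeholder for the class whose forward() is shown above
clips = torch.randn(8, 3, 16, 112, 112).cuda()   # a batch of video clips
out = pruner.forward(clips)
out.sum().backward()                      # triggers compute_rank via the registered hooks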

Please help me!