Aggregating the results of forward/backward hooks on nn.DataParallel (multi-GPU)

Hi, I want to use hook functions on a DataParallel instance, but there are a few points I'm unsure about.
I think the code below should work, but I'm not sure about:

  1. Do I need a lock when I access self.target_outputs and self.target_outputs_grad in the hook functions?
  2. Is it guaranteed that the inputs are scattered to the GPUs in their original order? For example, if the inputs are [1, 2, 3, 4] and there are 2 GPUs, are [1, 2] fed to GPU #1 and [3, 4] to GPU #2?

class Wrapper:

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        # One entry per replica, keyed by the device that replica runs on.
        self.target_outputs = {}
        self.target_outputs_grad = {}

        def forward_hook(_, __, output):
            self.target_outputs[output.device] = output.detach()

        def backward_hook(_, __, grad_output):
            assert len(grad_output) == 1
            self.target_outputs_grad[grad_output[0].device] = grad_output[0].detach()

        self.target_layer.register_forward_hook(forward_hook)
        # register_backward_hook is deprecated since PyTorch 1.8 in favor of register_full_backward_hook.
        self.target_layer.register_backward_hook(backward_hook)

Any comments would be appreciated! Thanks!


For question 1, I found that the PyTorch code that aggregates the outputs of each replicated module uses a lock.

So it seems I should use a threading.Lock as well.
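For reference, here is a minimal torch-free sketch of that pattern, assuming one worker thread per replica as in torch.nn.parallel.parallel_apply. The worker() function and its outputs are hypothetical stand-ins for a replica's forward pass and hook; each thread writes its result into a shared dict while holding a threading.Lock.

```python
import threading

results = {}
lock = threading.Lock()

def worker(device_index, chunk):
    # Hypothetical stand-in for one replica's forward pass plus its hook.
    output = [x * 10 for x in chunk]
    with lock:  # serialize writes to the shared dict, as parallel_apply does
        results[device_index] = output

chunks = [[1, 2], [3, 4]]
threads = [threading.Thread(target=worker, args=(i, c)) for i, c in enumerate(chunks)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)
```

As a side note, in CPython a single assignment to distinct dict keys is generally atomic under the GIL, so the lock mainly guards read-modify-write updates of shared state; taking it anyway is cheap.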

For question 2, I traced how the inputs are scattered:

inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
=> scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
=> Scatter.apply(target_gpus, None, dim, obj)
=> comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
=> tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
=> scatter(tensor, devices, chunk_sizes, dim, streams);
=> tensor.chunk(/*chunks=*/devices.size(), /*dim=*/dim);
=> self.split(split_size, dim);
=> self.narrow(dim, i * split_size, length);

It seems the inputs are scattered to the GPUs in their original order, so concatenating the outputs of each GPU in device order (the order of device_ids) matches the original inputs. Indeed, this is what the gather step of nn.DataParallel does.
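The trace can be mirrored with a small pure-Python sketch (an illustration, not PyTorch's code): chunk() reproduces the ceil-based split_size followed by contiguous narrow() slices, scatter() sends chunk i to device_ids[i], and gather() concatenates per-device hook outputs in device_ids order. The dict keys here are plain ints standing in for torch.device objects; with the real dict from the first post, the analogous step would be torch.cat over its values taken in device_ids order.

```python
import math

def chunk(seq, chunks):
    """Mimic Tensor.chunk along dim 0: split_size = ceil(n / chunks), then contiguous slices."""
    split_size = max(1, math.ceil(len(seq) / chunks))
    return [seq[start:start + split_size] for start in range(0, len(seq), split_size)]

def scatter(seq, device_ids):
    """Chunk i goes to device_ids[i]."""
    return dict(zip(device_ids, chunk(seq, len(device_ids))))

def gather(per_device, device_ids):
    """Concatenate per-device results in device_ids order."""
    out = []
    for device_id in device_ids:
        out.extend(per_device[device_id])
    return out

device_ids = [0, 1]
scattered = scatter([1, 2, 3, 4], device_ids)  # {0: [1, 2], 1: [3, 4]}
# Pretend each "GPU" doubled its chunk inside a hook; the dict is filled in arbitrary order.
hook_outputs = {1: [x * 2 for x in scattered[1]], 0: [x * 2 for x in scattered[0]]}
print(gather(hook_outputs, device_ids))  # [2, 4, 6, 8]
```

Gathering by position in device_ids rather than by sorted device index also stays correct if device_ids is not in ascending order.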

Could you please show us what your final code looks like?
I also tried to get intermediate outputs with a forward hook on multiple GPUs, but something weird happened:

  • In __init__, I initialize self.target_outputs as None.
  • Inside the hook function, I print self.target_outputs, and it is not None when the hook is executed.
  • But after self.model.forward() has executed, self.target_outputs is None again.

One of the two forward passes outputs None, which is very weird.

I implement this with the data_parallel function and make Wrapper an nn.Module whose forward returns self.target_outputs, so data_parallel scatters the inputs and gathers the outputs of all Wrapper replicas.

The code looks like this:

import torch
from torchvision.models.vgg import vgg19

class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super(Wrapper, self).__init__()
        self.model = model
        self.target_outputs = None

        def forward_hook(_, __, output):
            self.target_outputs = output.detach()
        self.model.features[2].register_forward_hook(forward_hook)

    def forward(self,input):
        self.model(input)
        return self.target_outputs


model = vgg19()
model = model.cuda(4)
wrapper = Wrapper(model)

devices = [4, 5]

inputs = torch.randn(60,3,224,224).fill_(0).cuda(4)
out = torch.nn.parallel.data_parallel(wrapper, inputs, devices)
print("first forward:   ", out)
inputs = torch.randn(60,3,224,224).fill_(1).cuda(4)
out = torch.nn.parallel.data_parallel(wrapper, inputs, devices)
print("second forward: ", out.shape)

The output is:

first forward:  None
second forward: (60, 64, 224, 224)

By comparing with the single-GPU result, I found that the output of the second forward pass is actually the result of the first forward pass.
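One likely explanation, assuming replicate() makes each replica a shallow copy of the module's __dict__ (as Module._replicate_for_data_parallel does in recent PyTorch versions): the hook closure captured `self` in __init__, so it always writes to the original Wrapper, while each replica's forward reads its own copied target_outputs, which still holds whatever the original had at replication time. The torch-free model below (Layer, replicate, and data_parallel are simplified hypothetical stand-ins, not PyTorch's classes) reproduces both symptoms.

```python
import threading

class Layer:
    """Stand-in for model.features[2]: computes something and runs forward hooks."""
    def __init__(self):
        self.forward_hooks = []

    def __call__(self, x):
        output = x * 2  # hypothetical layer computation
        for hook in self.forward_hooks:
            hook(self, x, output)
        return output

class Wrapper:
    def __init__(self):
        self.layer = Layer()
        self.target_outputs = None

        def forward_hook(_, __, output):
            # `self` was captured in __init__: this is always the ORIGINAL wrapper.
            self.target_outputs = output

        self.layer.forward_hooks.append(forward_hook)

    def forward(self, x):
        self.layer(x)
        return self.target_outputs  # reads the REPLICA's own copied attribute

def replicate(module):
    # Shallow copy of the instance dict, modelled on Module._replicate_for_data_parallel.
    replica = module.__new__(type(module))
    replica.__dict__ = module.__dict__.copy()
    return replica

def data_parallel(module, inputs):
    replicas = [replicate(module) for _ in inputs]  # copied before any thread runs
    results = [None] * len(inputs)

    def run(i):
        results[i] = replicas[i].forward(inputs[i])

    threads = [threading.Thread(target=run, args=(i,)) for i in range(len(inputs))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

wrapper = Wrapper()
print(data_parallel(wrapper, [0, 0]))  # [None, None]: replicas copied the initial None
print(data_parallel(wrapper, [1, 1]))  # [0, 0]: stale values from the first call
```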


There is a solution for getting intermediate outputs with forward hooks on multiple GPUs in this post, although it is not very elegant. It does work in my tests.
But the weird phenomenon described in the post above still confuses me.

@wwiiiii Hi, do you know how Scatter.apply calls Scatter.forward? I am confused about this.
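Roughly: Scatter is a torch.autograd.Function subclass, and Function.apply (implemented largely in PyTorch's C++ bindings) creates a context object, calls the class's static forward(ctx, *args), and records a node in the autograd graph that will later call backward. The pure-Python model below only mimics the forward dispatch; Ctx, Function, and this Scatter are hypothetical simplifications, not PyTorch's classes.

```python
class Ctx:
    """Hypothetical stand-in for the ctx object autograd passes to forward/backward."""

class Function:
    @classmethod
    def apply(cls, *args):
        ctx = Ctx()
        output = cls.forward(ctx, *args)  # dispatch to the subclass's static forward
        # Real autograd would also attach a graph node that later calls cls.backward(ctx, ...).
        return output

class Scatter(Function):
    @staticmethod
    def forward(ctx, n_devices, data):
        ctx.n_devices = n_devices  # remembered for a (not modelled) backward
        size = max(1, -(-len(data) // n_devices))  # ceil division
        return tuple(data[i:i + size] for i in range(0, len(data), size))

print(Scatter.apply(2, [1, 2, 3, 4]))  # ([1, 2], [3, 4])
```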

This answer may come late, but I hope it helps someone, or that someone finds a better way.

I solved it using the thread identifier. It is not the best way, but it has worked for me. This way you avoid having the two replica threads block each other on a lock.

import torch
from torchvision.models.vgg import vgg19
import threading
from collections import defaultdict

class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super(Wrapper, self).__init__()
        self.model = model
        self.target_outputs = defaultdict(lambda: None)

        def forward_hook(_, __, output):
            self.target_outputs[threading.get_native_id()] = output.detach()
        self.model.features[2].register_forward_hook(forward_hook)

    def forward(self,input):
        self.model(input)

        thread_id = threading.get_native_id()
        result = self.target_outputs[thread_id]
        del self.target_outputs[thread_id]

        return result


model = vgg19()
model = model.cuda(4)
wrapper = Wrapper(model)

devices = [4, 5]

inputs = torch.randn(60,3,224,224).fill_(0).cuda(4)
out = torch.nn.parallel.data_parallel(wrapper, inputs, devices)
print("first forward:   ", out)
inputs = torch.randn(60,3,224,224).fill_(1).cuda(4)
out = torch.nn.parallel.data_parallel(wrapper, inputs, devices)
print("second forward: ", out.shape)
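Why this works, assuming DataParallel's replicas are shallow copies of the module (so every replica's target_outputs attribute points at the same dict the hook writes into) and that each replica's forward, including its hooks, runs in its own worker thread via parallel_apply. The model below uses hypothetical simplified Wrapper, replicate, and data_parallel stand-ins rather than PyTorch's classes.

```python
import threading

class Wrapper:
    def __init__(self):
        self.target_outputs = {}  # one shared dict; replicas get a reference, not a copy

        def forward_hook(_, __, output):
            # Runs in the same thread as the replica's forward, so the thread ID is a per-replica key.
            self.target_outputs[threading.get_ident()] = output

        self.hook = forward_hook

    def forward(self, x):
        self.hook(self, x, x * 2)  # stand-in for the inner model firing its forward hook
        return self.target_outputs.pop(threading.get_ident())

def replicate(module):
    # Shallow copy of the instance dict, modelled on Module._replicate_for_data_parallel.
    replica = module.__new__(type(module))
    replica.__dict__ = module.__dict__.copy()
    return replica

def data_parallel(module, inputs):
    replicas = [replicate(module) for _ in inputs]
    results = [None] * len(inputs)

    def run(i):
        results[i] = replicas[i].forward(inputs[i])

    threads = [threading.Thread(target=run, args=(i,)) for i in range(len(inputs))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

wrapper = Wrapper()
print(data_parallel(wrapper, [0, 0]))  # [0, 0]
print(data_parallel(wrapper, [1, 2]))  # [2, 4]
```

Here pop() reads and removes the entry in one call, which has the same effect as the lookup followed by del in the post.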

Hi,
Thanks for the amazing solution. It works correctly for the forward pass.
I am using forward hooks together with backward hooks and doing some operations on their results. Each time backward is called, it runs on a different thread than the forward pass, so the backward hook's output is stored under that thread's ID. My problem is: how do I find the backward thread ID that corresponds to a given forward thread ID?

self.forward_hook = self.target_layer.register_forward_hook(self.hook_fn_act)
self.backward_hook = self.target_layer.register_backward_hook(self.hook_fn_grad)

def hook_fn_act(self, module, input, output):
    self.activations[threading.get_native_id()] = output.detach()

def hook_fn_grad(self, module, grad_input, grad_output):
    self.gradients[threading.get_native_id()] = grad_output[0].detach()

def foo(self):
    # Certain operations using self.activations and self.gradients.
    # Problem: the thread IDs used as keys for activations and gradients are not the same,
    # so how do I know which activation entry maps to which gradient entry?
    pass
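One way around this, assuming each replica runs on exactly one device (which is how DataParallel places them): key both hooks by the tensor's device instead of the thread ID, as in the first post of this thread, i.e. self.activations[output.device] and self.gradients[grad_output[0].device]. Backward hooks typically run on autograd's own worker threads rather than the forward threads, but a replica's activation and the gradient flowing back into it live on the same device. The torch-free sketch below uses a hypothetical FakeTensor with an integer device attribute and plain threads to show the pairing.

```python
import threading
from dataclasses import dataclass

@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: a value plus the index of the GPU it lives on."""
    value: float
    device: int

activations = {}
gradients = {}

def hook_fn_act(output):
    # Keyed by device rather than by threading.get_native_id().
    activations[output.device] = output

def hook_fn_grad(grad_output):
    gradients[grad_output.device] = grad_output

def run_in_new_thread(fn, arg):
    thread = threading.Thread(target=fn, args=(arg,))
    thread.start()
    thread.join()

device_ids = [0, 1]
# "Forward" and "backward" hooks fire in separate threads here, so thread IDs are not a pairing key.
for i in device_ids:
    run_in_new_thread(hook_fn_act, FakeTensor(float(i), i))
for i in device_ids:
    run_in_new_thread(hook_fn_grad, FakeTensor(10.0 + i, i))

# Pair each replica's activation with its gradient via the shared device key.
pairs = [(activations[d].value, gradients[d].value) for d in device_ids]
print(pairs)  # [(0.0, 10.0), (1.0, 11.0)]
```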

Thanks for the great solution.

For Python 3.7, use threading.get_ident() instead, because threading.get_native_id() was only added in Python 3.8.
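A small compatibility shim, a sketch rather than anything from the posts in this thread, picks whichever function exists. Either works as a dictionary key in this pattern, since only uniqueness among concurrently running threads matters.

```python
import threading

# Prefer get_native_id (Python 3.8+), fall back to get_ident on older versions.
get_thread_id = getattr(threading, "get_native_id", threading.get_ident)

print(isinstance(get_thread_id(), int))  # True
```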

The following code works for me.

import torch
import threading
from torchvision.models.vgg import vgg19

class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super(Wrapper, self).__init__()
        self.model = model
        self.target_outputs = {}

        def forward_hook(_, __, output):
            self.target_outputs[threading.get_ident()] = output.detach()
        self.model.features[2].register_forward_hook(forward_hook)

    def forward(self, input):
        self.model(input)
        thread_id = threading.get_ident()
        result = self.target_outputs[thread_id]
        del self.target_outputs[thread_id]
        return result


model = vgg19()
model = model.cuda(0)
wrapper = Wrapper(model)

devices = [0, 1]

inputs = torch.randn(60,3,224,224).fill_(0).cuda(1)
out = torch.nn.parallel.data_parallel(wrapper, inputs, devices)
print("first forward:   ", out.shape)
inputs = torch.randn(60,3,224,224).fill_(1).cuda(1)
out = torch.nn.parallel.data_parallel(wrapper, inputs, devices)
print("second forward: ", out.shape)

The result is

first forward:    torch.Size([60, 64, 224, 224])
second forward:  torch.Size([60, 64, 224, 224])