Using hooks with `DataParallel` (autograd error)

I have a problem with `nn.DataParallel`. I am trying to get intermediate feature maps via forward hooks from a `DataParallel` model and use these feature maps to compute a loss.

import torch.nn as nn
import torch.nn.functional as F
import torch

class DeepInversionFeatureHook():
    '''
    Implementation of the forward hook to track feature statistics and compute a loss on them.
    Will compute mean and variance, and will use l2 as a loss
    '''

    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # hook to compute DeepInversion's feature distribution regularization
        nch = input[0].shape[1]
        mean = input[0].mean([0, 2, 3])
        var = input[0].permute(1, 0, 2, 3).contiguous().view(
            [nch, -1]).var(1, unbiased=False)

        # forcing mean and variance to match between two distributions
        # other ways might work better, e.g. KL divergence
        r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(
            module.running_mean.data - mean, 2)

        self.r_feature = r_feature
        # must have no output

    def close(self):
        self.hook.remove()
        
        
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.in_planes = 64
 
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        
        
        # register forward hooks on every BatchNorm2d layer
        self.loss_r_feature_layers = []
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                self.loss_r_feature_layers.append(
                    DeepInversionFeatureHook(module))
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        
        # sum the per-BatchNorm regularization terms captured by the hooks
        loss_r_feature = sum(hook.r_feature for hook in self.loss_r_feature_layers)
        
        return x, loss_r_feature
    
net = nn.DataParallel(MyNet().cuda())
# net = MyNet().cuda()

output, extra_loss = net(torch.randn(512, 3, 32, 32).cuda())

print('=> extra_loss:', extra_loss.size(), extra_loss.device, extra_loss)
print('=> extra_loss sum:', extra_loss.sum().size(), extra_loss.sum().device, extra_loss.sum())

# loss = F.mse_loss(output, torch.ones_like(output))
# print('=> mse loss:', loss.size(), loss.device, loss)
print('=> output:', output.size(), output.device, output.grad_fn)
loss = extra_loss.sum()

loss.backward()
net.zero_grad()

Then I got:

=> extra_loss: torch.Size([8]) cuda:0 tensor([11.8416, 11.8376, 11.8413, 11.8440, 11.8397, 11.8390, 11.8433, 11.8395],
       device='cuda:0', grad_fn=<GatherBackward>)
=> extra_loss sum: torch.Size([]) cuda:0 tensor(94.7259, device='cuda:0', grad_fn=<SumBackward0>)
=> output: torch.Size([512, 64, 32, 32]) cuda:0 <torch.autograd.function.GatherBackward object at 0x7f246222cac8>
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-34c98315eba0> in <module>
      9 loss = extra_loss.sum()
     10 
---> 11 loss.backward()
     12 net.zero_grad()

/opt/conda/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    196                 products. Defaults to ``False``.
    197         """
--> 198         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    199 
    200     def register_hook(self, hook):

/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     98     Variable._execution_engine.run_backward(
     99         tensors, grad_tensors, retain_graph, create_graph,
--> 100         allow_unreachable=True)  # allow_unreachable flag
    101 
    102 

RuntimeError: Function AddBackward0 returned an invalid gradient at index 1 - expected device cuda:0 but got cuda:3 (validate_outputs at ../torch/csrc/autograd/engine.cpp:561)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x6c (0x7f23ac19e6fc in /opt/conda/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x27dacf4 (0x7f241d705cf4 in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #2: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0x4f5 (0x7f241d7077b5 in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #3: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x4df (0x7f241d70971f in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #4: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0xc0 (0x7f241d6febe0 in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #5: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x50 (0x7f242110cfb0 in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0xc819d (0x7f245fa7b19d in /opt/conda/bin/../lib/libstdc++.so.6)
frame #7: <unknown function> + 0x76db (0x7f24631ab6db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #8: clone + 0x3f (0x7f2462ed488f in /lib/x86_64-linux-gnu/libc.so.6)

This bug is very weird. The first two iterations run without any error, but after that I get this error every time.
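
For reference, here is a small diagnostic sketch (assuming the same `net` and hook objects from the script above, run right after the `DataParallel` forward pass) that prints which device each hook's stored `r_feature` ends up on. It is only meant to show where the cross-device tensors might come from, not a fix:

# Diagnostic sketch (assumes `net = nn.DataParallel(MyNet().cuda())` from above).
# Each DeepInversionFeatureHook keeps whichever r_feature was written last by
# hook_fn; printing its device shows whether it still lives on cuda:0 or on one
# of the other replicas' devices after the DataParallel forward.
for idx, hook in enumerate(net.module.loss_r_feature_layers):
    r = hook.r_feature
    print(f'hook {idx}: device={r.device}, value={r.item():.4f}')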