I have a problem with nn.DataParallel
. I am trying to get intermediate feature maps via hooks from a DataParallel
model and use these feature maps to compute loss.
import torch.nn as nn
import torch.nn.functional as F
import torch
class DeepInversionFeatureHook():
'''
Implementation of the forward hook to track feature statistics and compute a loss on them.
Will compute mean and variance, and will use l2 as a loss
'''
def __init__(self, module):
self.hook = module.register_forward_hook(self.hook_fn)
def hook_fn(self, module, input, output):
# hook co compute deepinversion's feature distribution regularization
nch = input[0].shape[1]
mean = input[0].mean([0, 2, 3])
var = input[0].permute(1, 0, 2, 3).contiguous().view(
[nch, -1]).var(1, unbiased=False)
# forcing mean and variance to match between two distributions
# other ways might work better, i.g. KL divergence
r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(
module.running_mean.data - mean, 2)
self.r_feature = r_feature
# must have no output
def close(self):
self.hook.remove()
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.in_planes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(64)
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(64)
# regist hooks
self.loss_r_feature_layers = []
for module in self.modules():
if isinstance(module, nn.BatchNorm2d):
self.loss_r_feature_layers.append(
DeepInversionFeatureHook(module))
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.conv3(x)
x = self.bn3(x)
loss_r_feature = sum([mod.r_feature
for (idx, mod) in enumerate(self.loss_r_feature_layers)])
return x, loss_r_feature
net = nn.DataParallel(MyNet().cuda())
# net = MyNet().cuda()
output, extra_loss = net(torch.randn(512, 3, 32, 32).cuda())
print('=> extra_loss:', extra_loss.size(), extra_loss.device, extra_loss)
print('=> extra_loss sum:', extra_loss.sum().size(), extra_loss.sum().device, extra_loss.sum())
# loss = F.mse_loss(output, torch.ones_like(output))
# print('=> mse loss:', loss.size(), loss.device, loss)
print('=> output:', output.size(), output.device, output.grad_fn)
loss = extra_loss.sum()
loss.backward()
net.zero_grad()
Then, I got
=> extra_loss: torch.Size([8]) cuda:0 tensor([11.8416, 11.8376, 11.8413, 11.8440, 11.8397, 11.8390, 11.8433, 11.8395],
device='cuda:0', grad_fn=<GatherBackward>)
=> extra_loss sum: torch.Size([]) cuda:0 tensor(94.7259, device='cuda:0', grad_fn=<SumBackward0>)
=> output: torch.Size([512, 64, 32, 32]) cuda:0 <torch.autograd.function.GatherBackward object at 0x7f246222cac8>
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-3-34c98315eba0> in <module>
9 loss = extra_loss.sum()
10
---> 11 loss.backward()
12 net.zero_grad()
/opt/conda/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
196 products. Defaults to ``False``.
197 """
--> 198 torch.autograd.backward(self, gradient, retain_graph, create_graph)
199
200 def register_hook(self, hook):
/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
98 Variable._execution_engine.run_backward(
99 tensors, grad_tensors, retain_graph, create_graph,
--> 100 allow_unreachable=True) # allow_unreachable flag
101
102
RuntimeError: Function AddBackward0 returned an invalid gradient at index 1 - expected device cuda:0 but got cuda:3 (validate_outputs at ../torch/csrc/autograd/engine.cpp:561)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x6c (0x7f23ac19e6fc in /opt/conda/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x27dacf4 (0x7f241d705cf4 in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #2: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0x4f5 (0x7f241d7077b5 in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #3: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x4df (0x7f241d70971f in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #4: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0xc0 (0x7f241d6febe0 in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #5: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x50 (0x7f242110cfb0 in /opt/conda/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0xc819d (0x7f245fa7b19d in /opt/conda/bin/../lib/libstdc++.so.6)
frame #7: <unknown function> + 0x76db (0x7f24631ab6db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #8: clone + 0x3f (0x7f2462ed488f in /lib/x86_64-linux-gnu/libc.so.6)
This bug is very wired. The first 2 iterations just works without any error. However, after the first 2 iterations, you can see this bug forever.