Hi,
I am trying to use `register_forward_hook` and `register_full_backward_hook` on the layers of an `InvertedResidual` block. However, I get an error caused by the block's in-place residual connection `result += input`, with this message:
```
  File "/home/usr/anaconda3/envs/pytorch-1.10-nightly/lib/python3.8/site-packages/torchvision/models/mobilenetv3.py", line 127, in forward
    result += input
RuntimeError: Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
```
Changing `result += input` to `result = result + input` makes the error go away. However, I would like to know whether there is a workaround that does not require changing the model code. Residual connections are very common, and most of them are written as `result += input`.
Here is my test code:
import torch
from torch import nn
from functools import partial
from torchvision.models.mobilenetv3 import InvertedResidualConfig
from torchvision.models.mobilenetv3 import InvertedResidual
def fwd_hook(m, x, y):
    """Forward hook: print the shape of the first positional input and of the output.

    Args:
        m: the module the hook is attached to (unused).
        x: tuple of positional inputs passed to the module's forward.
        y: the module's output (expected to expose ``.shape``).
    """
    for label, tensor in (("x shape: ", x[0]), ("y shape: ", y)):
        print(label, tensor.shape)
def bwd_hook(m, gx, gy):
    """Full backward hook: print the shapes of the first grad_input / grad_output.

    Args:
        m: the module the hook is attached to (unused).
        gx: grad_input tuple, one entry per positional input of the module.
        gy: grad_output tuple, one entry per output of the module.

    Entries of these tuples can be None (e.g. for non-Tensor arguments or
    inputs that do not need a gradient), so print None instead of raising
    AttributeError on ``None.shape`` in the middle of ``loss.backward()``.
    """
    for label, grads in (("gx shape: ", gx), ("gy shape: ", gy)):
        grad = grads[0]
        print(label, None if grad is None else grad.shape)
def add_hooks(m_):
    """Attach ``fwd_hook`` and ``bwd_hook`` to Conv2d, BatchNorm2d and ReLU modules.

    Intended for ``model.apply(add_hooks)``. Only these exact types are
    matched (subclasses are skipped), as with a ``type(...) in`` check.

    Note: a full backward hook routes the module's output through
    BackwardHookFunction, so a parent that modifies that output in place
    (e.g. ``result += input``) raises a view+inplace RuntimeError.

    Args:
        m_: the module being visited.

    Returns:
        list: the registered hook handles (empty when ``m_`` is not a hooked
        type), so the hooks can later be removed with ``handle.remove()``.
        ``nn.Module.apply`` ignores this return value.
    """
    if type(m_) not in (nn.Conv2d, nn.BatchNorm2d, nn.ReLU):
        return []
    # Keep both handles; binding them to one name would drop the first.
    return [
        m_.register_forward_hook(fwd_hook),
        m_.register_full_backward_hook(bwd_hook),
    ]
class Model(nn.Module):
    """Tiny test network: one InvertedResidual block, global average pool, linear head."""

    def __init__(self) -> None:
        super().__init__()
        # BatchNorm factory handed to InvertedResidual as its norm_layer.
        bn_factory = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
        # Positional args are presumably (input_channels, kernel, expanded_channels,
        # out_channels, use_se, activation, stride, dilation) — confirm against
        # torchvision's InvertedResidualConfig.
        block_cfg = partial(InvertedResidualConfig, width_mult=1)(
            16, 3, 16, 16, False, "RE", 1, 1
        )
        self.block = InvertedResidual(block_cfg, bn_factory)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(nn.Linear(16, 3))

    def forward(self, x):
        """Run block -> pool -> flatten (keeping dim 0) -> classifier."""
        pooled = self.avgpool(self.block(x))
        return self.classifier(torch.flatten(pooled, 1))
model = Model()
model.train()

# In-place ops on tensors that pass through a full backward hook are
# rejected by autograd, so turn off in-place activations everywhere.
for m in model.modules():
    if hasattr(m, "inplace"):
        m.inplace = False


def _clone_tensor_output(module, inputs, output):
    """Forward hook: replace a tensor output with a clone of it.

    A module with a full backward hook returns its output through
    BackwardHookFunction, and that wrapping happens *after* the module's own
    forward hooks run, so the clone cannot be done on the hooked module
    itself. Registering this on the enclosing containers instead hands the
    parent (e.g. InvertedResidual's ``result += input``) a fresh tensor it is
    allowed to modify in place, without editing the model code.
    Non-tensor outputs are left unchanged (returning None keeps the output).
    """
    if isinstance(output, torch.Tensor):
        return output.clone()
    return None


# Containers are modules with children; add_hooks only hooks leaf layers.
for m in model.modules():
    if next(m.children(), None) is not None:
        m.register_forward_hook(_clone_tensor_output)

model.apply(add_hooks)

criterion = torch.nn.CrossEntropyLoss()
input_size = (1, 16, 112, 112)
device = next(model.parameters()).device
# Create the input on the model's device directly: Tensor.to() is not
# in-place (the old bare `x.to(...)` was a no-op), and building it here
# keeps x a leaf tensor that requires grad.
x = torch.zeros(input_size, device=device, requires_grad=True)
y = model(x)
target = torch.zeros(input_size[0], dtype=torch.long, device=device)
loss = criterion(y, target)
loss.backward()
Thanks!