# "failed to compute its gradient" when using torch.flip

I’m trying to flip some of the channels in my model as follows (note that the input is of size (N*T, C, H, W)):

``````
import torch
import torch.nn as nn


class Flip(nn.Module):
    def __init__(self, t, inplace=True):
        super(Flip, self).__init__()
        self.t = t
        self.inplace = inplace

    def forward(self, x):
        x = self.flip(x, self.t, self.inplace)
        return x

    @staticmethod
    def flip(x, t, inplace=True):
        if inplace:
            out = InplaceFlip.apply(x, t)
        else:
            # Input: (NT,C,H,W)
            nt, c, h, w = x.size()
            n = nt // t

            x = x.view(n, t, c, h, w)

            buffer = torch.zeros(n, t // 2, c, h, w)

            # even
            buffer = x[:, 1::2]

            buffer = buffer.flip(2)

            x[:, 1::2] = buffer
            x = x.view(nt, c, h, w)
        return x


class InplaceFlip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, t):
        ctx.t_ = t

        # Input: (NT,C,H,W)
        nt, c, h, w = x.size()
        n = nt // t

        x = x.view(n, t, c, h, w)

        buffer = x.data.new(n, t // 2, c, h, w).zero_()

        # even
        buffer = x[:, 1::2]

        buffer = buffer.flip(2)

        x[:, 1::2] = buffer  # in-place write into a view of the input
        x = x.view(nt, c, h, w)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        t = ctx.t_

        # Input: (NT,C,H,W)
        nt, c, h, w = grad_output.size()
        n = nt // t

        grad_output = grad_output.view(n, t, c, h, w)

        buffer = grad_output.data.new(n, t // 2, c, h, w).zero_()

        # even: the flip is its own inverse, so apply it to the gradient too
        buffer = grad_output[:, 1::2]

        buffer = buffer.flip(2)

        grad_output[:, 1::2] = buffer
        grad_output = grad_output.view(nt, c, h, w)
        return grad_output, None  # no gradient for t
``````
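
To make the intended semantics concrete, here is a toy forward-pass illustration (my own example, not from the original model): with `t=4`, the channels of frames 1 and 3, the ones the `# even` comments refer to, are reversed, while frames 0 and 2 pass through unchanged.

``````
import torch

flip = Flip(t=4, inplace=False)
x = torch.arange(4 * 3).float().view(4, 3, 1, 1)  # NT=4 (N=1, T=4), C=3, 1x1 spatial
print(x.view(4, 3))  # frame 1 is [3, 4, 5], frame 3 is [9, 10, 11]
y = flip(x)
print(y.view(4, 3))  # frame 1 becomes [5, 4, 3], frame 3 becomes [11, 10, 9]
``````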

It works normally when I test it as follows:

``````
flip1 = Flip(t=8, inplace=False)
flip2 = Flip(t=8, inplace=True)

# test forward
for i in range(10):
    x = torch.rand(2 * 8, 64, 7, 7)
    y1 = flip1(x)
    y2 = flip2(x)
    assert torch.norm(y1 - y2).item() < 1e-5

# test backward
for i in range(10):
    x1 = torch.rand(2 * 8, 64, 7, 7)
    x2 = x1.clone()
    y1 = flip1(x1)
    y2 = flip2(x2)

flip1.cuda()
flip2.cuda()

# test forward
for i in range(10):
    x = torch.rand(2 * 8, 64, 7, 7).cuda()
    y1 = flip1(x)
    y2 = flip2(x)
    assert torch.norm(y1 - y2).item() < 1e-5

# test backward
for i in range(10):
    x1 = torch.rand(2 * 8, 64, 7, 7).cuda()
    x2 = x1.clone()
    y1 = flip1(x1)
    y2 = flip2(x2)
``````
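
As an aside, the `# test backward` loops above never actually call `.backward()`, so they only exercise the forward pass. A sketch of a proper gradient check (my own addition; `torch.autograd.gradcheck` needs double precision, and this assumes the out-of-place version from the fix further down, since the one above writes into a view of its input):

``````
import torch
from torch.autograd import gradcheck

flip = Flip(t=4, inplace=False)
x = torch.rand(2 * 4, 3, 5, 5, dtype=torch.double, requires_grad=True)
# compares the analytical Jacobian against finite differences
assert gradcheck(flip, (x,), eps=1e-6, atol=1e-4)
``````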

However, when I insert this module into my model (e.g. ResNet-50), it raises the following error:

``````
Traceback (most recent call last):
  File "main.py", line 496, in <module>
    main()
  File "main.py", line 279, in main
    train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer)
  File "main.py", line 338, in train
    loss.backward()
  File "/home/likunchang/.local/lib/python3.6/site-packages/torch/tensor.py", line 195, in backward
  File "/home/likunchang/.local/lib/python3.6/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [96, 2048, 7, 7]], which is output 0 of ReluBackward1, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
``````
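
For context, autograd keeps a version counter on every tensor and raises this error when a tensor that some backward function saved (here the output of a ReLU) is modified in place before backward runs. A minimal reproduction of the same failure mode (my own example, not the model code):

``````
import torch

x = torch.rand(4, requires_grad=True)
y = torch.relu(x)   # ReLU saves its output for the backward pass
y[0] = 0.0          # in-place write bumps y's version counter
y.sum().backward()  # RuntimeError: ... modified by an inplace operation
``````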

Maybe flipping only the even channels is what causes the problem, but I really need this operation. How can I fix it, or do you have a better suggestion for implementing it? Looking forward to your reply!

BTW, my code for this operation is based on https://github.com/mit-han-lab/temporal-shift-module/blob/832a758f0c1e4a835cb0a47d957eff776d35dd91/ops/temporal_shift.py#L47

I solved the problem…

``````
class Flip(nn.Module):
    def __init__(self, t, inplace=True):
        super(Flip, self).__init__()
        self.t = t
        self.inplace = inplace

    def forward(self, x):
        x = self.flip(x, self.t, self.inplace)
        return x

    @staticmethod
    def flip(x, t, inplace=True):
        if inplace:
            out = InplaceFlip.apply(x, t)
        else:
            # Input: (NT,C,H,W)
            nt, c, h, w = x.size()
            n = nt // t

            x = x.view(n, t, c, h, w)

            # write into a fresh tensor instead of modifying x in place
            out = torch.zeros_like(x)

            # even
            out[:, 1::2] = x[:, 1::2]

            out = out.flip(2)

            out[:, 0::2] = x[:, 0::2]
            out = out.view(nt, c, h, w)
        return out


class InplaceFlip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, t):
        ctx.t_ = t

        # Input: (NT,C,H,W)
        nt, c, h, w = x.size()
        n = nt // t

        x = x.view(n, t, c, h, w)

        buffer = x.data.new(n, t // 2, c, h, w).zero_()

        # even
        buffer = x[:, 1::2]

        buffer = buffer.flip(2)

        x[:, 1::2] = buffer
        x = x.view(nt, c, h, w)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        t = ctx.t_

        # Input: (NT,C,H,W)
        nt, c, h, w = grad_output.size()
        n = nt // t

        grad_output = grad_output.view(n, t, c, h, w)

        buffer = grad_output.data.new(n, t // 2, c, h, w).zero_()

        # even
        buffer = grad_output[:, 1::2]

        buffer = buffer.flip(2)

        grad_output[:, 1::2] = buffer
        grad_output = grad_output.view(nt, c, h, w)
        return grad_output, None
``````
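
For what it’s worth, the out-of-place branch can be written more compactly by cloning the input and overwriting only the frames that need flipping; autograd then derives the backward automatically, so the custom `Function` isn’t needed at all. A sketch of that variant (my own, with a hypothetical name `flip_even`, not from the TSM code):

``````
import torch

def flip_even(x, t):
    """Reverse the channel order of every other frame, out of place."""
    nt, c, h, w = x.size()
    x = x.view(nt // t, t, c, h, w)
    out = x.clone()                    # fresh tensor: no in-place writes on the input
    out[:, 1::2] = x[:, 1::2].flip(2)  # flip channels of frames 1, 3, ...
    return out.view(nt, c, h, w)
``````
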
Note that you should never use `.data`. Use `.detach()` if you want a Tensor with the same content that does not track gradients, or `with torch.no_grad():` if you want to locally disable autograd.
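
A quick illustration of that advice (my own example):

``````
import torch

x = torch.rand(3, requires_grad=True)

y = x.detach()          # same content and storage, but y does not track gradients
print(y.requires_grad)  # False

with torch.no_grad():   # locally disable autograd
    z = x * 2
print(z.requires_grad)  # False
``````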