"failed to compute its gradient" when using torch.flip

I’m trying to flip some of the channels in my model as follows,
(Notice that the input size is like N*T,C,H,W)

class Flip(nn.Module):
    def __init__(self, t, inplace=True):
        super(Flip, self).__init__()
        self.t = t
        self.inplace = inplace

    def forward(self, x):
        x = self.flip(x, self.t, self.inplace)
        return x

    @staticmethod
    def flip(x, t, inplace=True):
        if inplace:
            out = InplaceFlip.apply(x, t)
        else:
            # Input: (NT,C,H,W)
            nt, c, h, w = x.size()
            n = nt // t

            x = x.view(n, t, c, h, w)

            buffer = torch.zeros(n, t // 2, c, h, w)
            
            # even
            buffer = x[:, 1::2]
            
            buffer = buffer.flip(2)

            x[:, 1::2] = buffer 
            x = x.view(nt, c, h, w)
        return x


class InplaceFlip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, t):
        ctx.t_ = t

        # Input: (NT,C,H,W)
        nt, c, h, w = x.size()
        n = nt // t

        x = x.view(n, t, c, h, w)

        buffer = x.data.new(n, t // 2, c, h, w).zero_()
        
        # even
        buffer = x[:, 1::2]
        
        buffer = buffer.flip(2)

        x[:, 1::2] = buffer
        x = x.view(nt, c, h, w)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        t = ctx.t_

        # Input: (NT,C,H,W)
        nt, c, h, w = grad_output.size()
        n = nt // t

        grad_output = grad_output.view(n, t, c, h, w)

        buffer = grad_output.data.new(n, t // 2, c, h, w).zero_()
        
        # even
        buffer = grad_output[:, 1::2]
        
        buffer = buffer.flip(2)

        grad_output[:, 1::2] = buffer
        grad_output = grad_output.view(nt, c, h, w)
        return grad_output, None

It works normally when testing as follows,

flip1 = Flip(t=8, inplace=False)
flip2 = Flip(t=8, inplace=True)
# test forward
with torch.no_grad():
    for i in range(10):
        x = torch.rand(2 * 8, 64, 7, 7)
        y1 = flip1(x)
        y2 = flip2(x)
        assert torch.norm(y1 - y2).item() < 1e-5
# test backward
with torch.enable_grad():
    for i in range(10):
        x1 = torch.rand(2 * 8, 64, 7, 7)
        x1.requires_grad_()
        x2 = x1.clone()
        y1 = flip1(x1)
        y2 = flip2(x2)
        grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0]
        grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0]
        assert torch.norm(grad1 - grad2).item() < 1e-5
flip1.cuda()
flip2.cuda()
# test forward
with torch.no_grad():
    for i in range(10):
        x = torch.rand(2 * 8, 64, 7, 7).cuda()
        y1 = flip1(x)
        y2 = flip2(x)
        assert torch.norm(y1 - y2).item() < 1e-5
# test backward
with torch.enable_grad():
    for i in range(10):
        x1 = torch.rand(2 * 8, 64, 7, 7).cuda()
        x1.requires_grad_()
        x2 = x1.clone()
        y1 = flip1(x1)
        y2 = flip2(x2)
        grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0]
        grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0]
        assert torch.norm(grad1 - grad2).item() < 1e-5

However, when I tried to insert such module in my model, like ResNet-50, it occured some error as follows,

Traceback (most recent call last):
  File "main.py", line 496, in <module>
    main()
  File "main.py", line 279, in main
    train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer)
  File "main.py", line 338, in train
    loss.backward()
  File "/home/likunchang/.local/lib/python3.6/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/likunchang/.local/lib/python3.6/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [96, 2048, 7, 7]], which is output 0 of ReluBackward1, is at version 2; expected version 1 instead. Hint: the backtrace further above sho
ws the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Maybe trying to only flip the even channels results in the problem, but I really need to realize such operation. How can I fix it? Or do you have better suggestion to realize it? Hopefully to your reply!!!

BTW, my codes for such operation refer to https://github.com/mit-han-lab/temporal-shift-module/blob/832a758f0c1e4a835cb0a47d957eff776d35dd91/ops/temporal_shift.py#L47

I solve the problem…
I forgot to add .data so that the gradient was ignored…

class Flip(nn.Module):
    def __init__(self, t, inplace=True):
        super(Flip, self).__init__()
        self.t = t
        self.inplace = inplace

    def forward(self, x):
        x = self.flip(x, self.t, self.inplace)
        return x

    @staticmethod
    def flip(x, t, inplace=True):
        if inplace:
            out = InplaceFlip.apply(x, t)
        else:
            # Input: (NT,C,H,W)
            nt, c, h, w = x.size()
            n = nt // t

            x = x.view(n, t, c, h, w)

            out = torch.zeros_like(x)
            
            # even
            out[:, 1::2] = x[:, 1::2]
            
            out = out.flip(2)

            out[:, 0::2] = x[:, 0::2]
            out = out.view(nt, c, h, w)
        return out


class InplaceFlip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, t):
        ctx.t_ = t

        # Input: (NT,C,H,W)
        nt, c, h, w = x.size()
        n = nt // t

        x = x.view(n, t, c, h, w)

        buffer = x.data.new(n, t // 2, c, h, w).zero_()
        
        # even
        buffer = x[:, 1::2]
        
        buffer = buffer.flip(2)

        x[:, 1::2] = buffer
        x = x.view(nt, c, h, w)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        t = ctx.t_

        # Input: (NT,C,H,W)
        nt, c, h, w = grad_output.size()
        n = nt // t

        grad_output = grad_output.view(n, t, c, h, w)

        buffer = grad_output.data.new(n, t // 2, c, h, w).zero_()
        
        # even
        buffer = grad_output[:, 1::2]
        
        buffer = buffer.flip(2)

        grad_output[:, 1::2] = buffer
        grad_output = grad_output.view(nt, c, h, w)
        return grad_output, None

Hi,

Note that you should never use .data. You can use .detach() if you want a Tensor with the same content but that does not track gradients. Or with torch.no_grad(): if you want to locally disable the autograd.