Overriding torch.autograd.Function

Hi, when inheriting from autograd.Function to define an op, e.g.

from torch import autograd
from torch.autograd.function import once_differentiable

class MyFunc(autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        some_inplace_operation1(x)   # placeholder: modifies x in place
        return x

    @staticmethod
    @once_differentiable
    def backward(ctx, dy):
        some_inplace_operation2(dy)  # placeholder: modifies the incoming gradient in place
        return dy

In some cases the gradient is totally wrong (depending on the op that follows). What's the reason? A concrete example is here.

Hi,

You are not allowed to modify dy in place in the backward function. If you do, the computed gradients can be arbitrarily wrong :slight_smile:


Could you please explain more about why the computed gradients can be arbitrarily wrong? And is there a way to safely modify dy? Doing so can save memory and improve efficiency.

Because this Tensor is supposed to contain the value of the gradient for a given Tensor in the forward pass.
But this value might be needed in other places; if you modify it in place, those other places will see wrong values.
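
To make this concrete, here is a minimal sketch (not from this thread): the backward below computes a mathematically correct value, but writes it into the incoming gradient buffer. Because the add in the forward can hand the very same buffer to both of its inputs, the final gradient of x can come out wrong.

import torch
from torch import autograd
from torch.autograd.function import once_differentiable

class DoubleInPlace(autograd.Function):
    # y = 2 * x; the backward value (2 * dy) is mathematically right,
    # but it is written into the incoming gradient buffer.
    @staticmethod
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    @once_differentiable
    def backward(ctx, dy):
        dy.mul_(2)  # in-place edit of the gradient buffer handed to us
        return dy

x = torch.ones(3, requires_grad=True)
out = DoubleInPlace.apply(x) + x  # out = 3 * x, so d(out)/dx should be 3
out.sum().backward()
print(x.grad)  # expected tensor([3., 3., 3.]); if the add shares one gradient
               # buffer between its two inputs, the mul_ above leaks into the
               # other branch and this can come out as 4 instead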

Where else might it be used? Assume that the op following this one takes it as its only input.

In general, you should never do it :smiley: In some limited cases it might work, but this is not guaranteed.
Also, your forward function is missing a call to mark_dirty(). See the docs on extending the autograd.
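
For reference, a sketch of what the forward would look like with the mark_dirty() call (the placeholder op is kept from the original post):

class MyFunc(autograd.Function):
    @staticmethod
    def forward(ctx, x):
        some_inplace_operation1(x)  # placeholder from the original post
        ctx.mark_dirty(x)           # tell autograd that x was modified in place
        ctx.save_for_backward(x)
        return x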

I did call mark_dirty :blush: Is there a way to find out in which cases it is safe?

Could you give a full code sample that reproduces the error you're seeing?

You can try resnet18 with the InPlaceABN op at this commit. I trained it on CIFAR-10 using the code here with PyTorch 1.1.0; its validation accuracy is always lower than that of resnet18 with the standard BatchNorm2d, and there is a significant accuracy drop during training:


How can I fix this problem without modifying the op, since it is more memory efficient?

That's a lot of code… it would really help if you had a small code sample that we can run to easily see what it is doing.

Are you sure the accuracy drop comes from the autograd? Can you try temporarily adding dy = dy.clone() at the beginning of your backward function and see if you still have the same behavior?
Such behavior in the loss can also come from instability in your new operation.
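
In the posted Function, that temporary diagnostic would look roughly like this (again keeping the placeholder op from the original post):

    @staticmethod
    @once_differentiable
    def backward(ctx, dy):
        dy = dy.clone()              # diagnostic: work on a private copy of the gradient
        some_inplace_operation2(dy)  # placeholder from the original post
        return dy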

Sorry for all the code, let me make a minimal example later. I have tried that (adding dy = dy.clone() at the beginning of backward): there was no accuracy drop and the final accuracy was higher (by 5 percentage points).

Hi, here is a minimal code sample that reproduces the error. Please check it.
I found that the error is related to out = out + self.shortcut(x) when len(self.shortcut) == 0 and out is produced by the in-place op. What is the reason behind it?

import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
            self.bn2 = nn.BatchNorm2d(planes)
        else:
            # Using the name `n2` keeps this an `nn.BatchNorm2d` when `len(self.shortcut) == 0`,
            # which avoids the gradient error, but why?
            # self.n2 = nn.BatchNorm2d(planes)

            # With the name `bn2`, this layer gets replaced with `my_func*` when
            # `len(self.shortcut) == 0`, which causes the gradient error.
            self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.conv2(out)
        if len(self.shortcut):
            out = self.bn2(out)
        else:
            out = self.bn2(out)
            # out = self.n2(out)
        out = out + self.shortcut(x)
        out = F.relu(out, inplace=True)
        return out

There is no custom autograd Function in this code, right?
The code looks fine; what is the exact issue you're seeing with it?

The full sample code is https://gist.github.com/knsong/1b2aeaf1cdfef28d52138df5d05cd949 and my custom autograd Function is there. If you execute it with python backward_inplace.py, you will see the gradient error between the two versions of the Function (in-place vs. not in-place), e.g. grad x error max: tensor(0.8613).

As I said, you are not allowed to modify the grad in place. That changes what the hooks (and potentially other functions) see. So it is expected that you get wrong gradients if you modify the grad in place in the backward function!

It doesn't always result in wrong gradients, as I explained in the sample code; it only does here. So I would greatly appreciate some details about what happens in this case.

We would need to dig into how each of these functions is implemented.
In particular, we need to find out what uses this particular gradient buffer and in which order.
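
One hedged way to poke at that (not from this thread, and relying on internals that may change between versions) is to register hooks on the two tensors feeding an add and compare the storage of the gradients each one receives:

import torch

a = torch.ones(3, requires_grad=True)
b = a * 1.0
c = a * 2.0

seen = {}
b.register_hook(lambda g: seen.update(b=g))  # record the gradient reaching b
c.register_hook(lambda g: seen.update(c=g))  # record the gradient reaching c

(b + c).sum().backward()

# If this prints True, both branches of the add were handed the very same
# buffer, so an in-place edit in one backward would be visible to the other.
# Which branch runs first is an engine detail, which is why the resulting
# wrong values are hard to predict.
print(seen['b'].data_ptr() == seen['c'].data_ptr())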

My point was more that it is a dangerous thing to do. Even if you find out why it works in this case for this version of PyTorch, your code might silently be broken by the next version of PyTorch if we change some internal implementations of Functions or of the autograd engine :confused: