Backward() in custom layer is not called

What I want to do is modify the weights of a Conv2d after loss.backward() and before optimizer.step(). One solution is to modify the weights of the corresponding layers in the training loop after loss.backward(), but I would like to do it inside a custom layer so that the train() function stays clean. Here is the code snippet:

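import torch
import torch.nn as nn

class CustomConv(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super(CustomConv, self).__init__(*args, **kwargs)

    def forward(self, input):
        input = WeightModifier(self.weight)(input)
        out = nn.functional.conv2d(input, self.weight, self.bias, self.stride,
                                   self.padding, self.dilation, self.groups)
        return out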

class WeightModifier(torch.autograd.Function):
    def __init__(self, weight):
        self.weight = weight

    def forward(self, input):
        return input.clone()

    def backward(self, grad):
        # Some op to change self.weight.
        # In other words, change weights in the following conv2d
        # after loss.backward() and before optimizer.step()
        return grad.clone()

hidden = CustomConv()(input_tensor)  # backward not called 
hidden = CustomConv()(hidden)
loss = cal_loss(hidden, gt)

The problem is that the backward() of WeightModifier in the first CustomConv is not called (the one in the second CustomConv is). My guess is that PyTorch sees that input_tensor does not require a gradient and that WeightModifier has no parameters, so it skips this node. Is there any way to force (or "cheat") PyTorch into executing this backward()?
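The guess can be checked quickly. The sketch below (assuming a recent PyTorch; the tensors are made up for illustration) shows that autograd only records an operation when at least one of its inputs requires grad, so an op fed only by input_tensor gets no backward node at all:

```python
import torch

# Plays the role of input_tensor: it does not require gradients.
x = torch.randn(2, 3)

# The only input does not require grad, so the op is not recorded:
# there is no grad_fn, hence no backward() that could ever be called.
y = x.clone()
print(y.requires_grad, y.grad_fn)  # False None

# Once one input requires grad (e.g. a layer weight), the op is recorded.
w = torch.randn(3, requires_grad=True)
z = x * w
print(z.requires_grad, z.grad_fn is not None)  # True True
```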



The problem most likely comes from the fact that you are using old-style Functions. The docs show how to write new-style ones (static forward/backward methods that take a ctx argument, called through .apply()).
Do keep the .clone() in the forward, though; otherwise we detect that it's the same Tensor and the Function's backward won't be called.
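For reference, a minimal new-style Function might look like this sketch. CloneIdentity and its backward_calls counter are hypothetical names used only to show the pattern (static methods, ctx, .apply(), and the .clone() in forward):

```python
import torch


class CloneIdentity(torch.autograd.Function):
    """Hypothetical new-style Function: identity forward, pass-through backward."""

    backward_calls = 0  # only used to show that backward actually runs

    @staticmethod
    def forward(ctx, input):
        # Return a clone rather than `input` itself, as advised above.
        return input.clone()

    @staticmethod
    def backward(ctx, grad_output):
        CloneIdentity.backward_calls += 1
        # One gradient per input of forward() (here only `input`).
        return grad_output.clone()


x = torch.randn(4, requires_grad=True)
y = CloneIdentity.apply(x)  # new-style Functions are invoked via .apply()
y.sum().backward()
print(CloneIdentity.backward_calls)  # 1
print(x.grad)  # tensor([1., 1., 1., 1.])
```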


Hi albanD,
Thanks. I still can't get it to work. Following your suggestion, I tried something like this:

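class WeightModifier(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(weight)  # needed for ctx.saved_tensors in backward
        return input.clone()

    @staticmethod
    def backward(ctx, grad_output):
        weight, = ctx.saved_tensors
        # Only one gradient is returned, although forward() took two inputs
        return grad_output.clone()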

I call it inside CustomConv.forward() with input = WeightModifier.apply(input, self.weight). However, PyTorch requires backward() to return a gradient for both input and weight, whereas I only want to return one for input.

Sorry to bother you again, but I have another question. In this example, input.clamp(min=0) is not recorded in the autograd graph. Is that because it is called inside the forward() of an autograd.Function? More generally, does any module or function called inside a Function's forward() not contribute to gradient computation?


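The whole point of using a Function is to not use autograd: the operations in its forward() are not recorded, and you provide the gradient yourself in backward().
If you want autograd to track the operations, just use a regular Python function.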
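To see this concretely, the sketch below compares a clamp-based ReLU written as a Function with the same op in a plain function. ClampInFunction and relu_plain are hypothetical names; the inputs are chosen without exact zeros so that both gradients agree:

```python
import torch


class ClampInFunction(torch.autograd.Function):
    """Hypothetical ReLU-like Function that inspects grad mode inside forward."""

    grad_enabled_in_forward = None
    inner_grad_fn = None

    @staticmethod
    def forward(ctx, input):
        # forward() runs with grad mode off, so clamp is not recorded here.
        ClampInFunction.grad_enabled_in_forward = torch.is_grad_enabled()
        out = input.clamp(min=0)
        ClampInFunction.inner_grad_fn = out.grad_fn
        ctx.save_for_backward(input)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # The gradient is written by hand: pass it through where input > 0.
        return grad_output * (input > 0).to(grad_output.dtype)


def relu_plain(input):
    # A regular Python function: autograd records clamp and derives backward.
    return input.clamp(min=0)


x = torch.tensor([-2.0, -0.5, 0.5, 3.0], requires_grad=True)

y_custom = ClampInFunction.apply(x)
print(ClampInFunction.grad_enabled_in_forward)  # False
print(ClampInFunction.inner_grad_fn)            # None
(g_custom,) = torch.autograd.grad(y_custom.sum(), x)

y_plain = relu_plain(x)
(g_plain,) = torch.autograd.grad(y_plain.sum(), x)
print(g_custom, g_plain)  # both tensor([0., 0., 1., 1.])
```

Only the Function as a whole gets a node in the graph (y_custom.grad_fn), whereas in relu_plain the clamp itself is the recorded node.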

backward() must return one value per input of forward(); if you don't want to return a gradient for one of the inputs, return None in its place.
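Putting the answers together, a minimal sketch of the original goal could look like this. It assumes new-style Functions; the halving of the weight is a hypothetical placeholder for your own modification, and the layer sizes and the loss (standing in for cal_loss) are made up. Passing self.weight into apply() makes the output require grad, so the first layer's backward() runs even though input_tensor does not require grad:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightModifier(torch.autograd.Function):
    """Identity on `input`; backward() is used as a hook to edit `weight`."""

    backward_calls = 0  # only used to show that backward runs for each layer

    @staticmethod
    def forward(ctx, input, weight):
        # `weight` requires grad, so this node is recorded even if `input` doesn't.
        ctx.save_for_backward(weight)
        return input.clone()

    @staticmethod
    def backward(ctx, grad_output):
        weight, = ctx.saved_tensors
        WeightModifier.backward_calls += 1
        with torch.no_grad():
            # HYPOTHETICAL modification: replace with your own weight update.
            weight.mul_(0.5)
        # One value per forward input: a gradient for `input` (None if it
        # doesn't need one), and None for `weight`, which gets its real
        # gradient from conv2d.
        grad_input = grad_output.clone() if ctx.needs_input_grad[0] else None
        return grad_input, None


class CustomConv(nn.Conv2d):
    def forward(self, input):
        input = WeightModifier.apply(input, self.weight)
        return F.conv2d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


torch.manual_seed(0)
layer1 = CustomConv(3, 8, kernel_size=3, padding=1)
layer2 = CustomConv(8, 8, kernel_size=3, padding=1)
input_tensor = torch.randn(2, 3, 16, 16)  # does not require grad
w1_before = layer1.weight.detach().clone()

hidden = layer2(layer1(input_tensor))
loss = hidden.pow(2).mean()  # stand-in for cal_loss(hidden, gt)
loss.backward()

print(WeightModifier.backward_calls)  # 2: called for both layers
print(torch.allclose(layer1.weight.detach(), 0.5 * w1_before))  # True
```

The edit happens after each conv's own backward has already used the weight, so the gradients stored in .grad are those of the unmodified weights, and optimizer.step() then applies them to the modified ones.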
