How to override the gradients for parameters?

I want to do the following operations:
During forward, I defined a new conv layer:

import torch.nn as nn
import torch.nn.functional as F

class new_conv(nn.Conv2d):
    def forward(self, input):
        # self.weight is an nn.Parameter, so keep the rounded copy in a local variable
        weight = self.weight.round()
        output = F.conv2d(input, weight, self.bias…)
        return output

During backward, I want to override the gradient of self.weight.round() with an identity mapping. That is, round() should be treated as if its derivative were 1, so that gradient(self.weight.round()) is passed unchanged to self.weight.

One approach I know of is register_backward_hook(), but I don't know how to apply it in my case. In TensorFlow, simply using G.gradient_override_map({"Round": "Identity"}) works well.
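For reference, a common PyTorch idiom gives round() an identity backward without hooks: the "straight-through" detach trick. This is a different technique from TensorFlow's gradient_override_map and is not taken from this thread; the helper name round_ste is hypothetical. A minimal sketch:

```python
import torch


def round_ste(w: torch.Tensor) -> torch.Tensor:
    # Forward: same values as w.round().
    # Backward: the detached correction term carries no gradient, so the
    # gradient w.r.t. w is the incoming gradient unchanged (identity).
    return w + (w.round() - w).detach()


coeffs = torch.tensor([1.0, 2.0, 3.0])

w = torch.tensor([0.2, 1.7, -2.4], requires_grad=True)
(round_ste(w) * coeffs).sum().backward()
print(w.grad)        # tensor([1., 2., 3.]): gradient passed straight through

w_plain = torch.tensor([0.2, 1.7, -2.4], requires_grad=True)
(w_plain.round() * coeffs).sum().backward()
print(w_plain.grad)  # tensor([0., 0., 0.]): round() itself has zero gradient
```

Inside a layer like new_conv, this would amount to passing round_ste(self.weight) to F.conv2d instead of self.weight.round().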


Maybe this will work for you:

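import torch

class RoundNoGradient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, g):
        # straight-through: pass the incoming gradient on unchanged
        return g

m = torch.nn.Conv2d(16, 33, 3, stride=2)
l = torch.nn.L1Loss()
input = torch.randn(20, 16, 50, 100, requires_grad=True)  # plain tensors track gradients; Variable is deprecated
x = RoundNoGradient.apply(input)
#x = input.round()  # plain round() would leave input.grad all zeros
y = m(x)
output = l(y, torch.randn_like(y))  # compare the conv output (not x) with a target of matching shape
output.backward()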
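The snippet above rounds the input, while the question asks about rounding the conv weights. A sketch of the same Function applied to the weights of a Conv2d subclass (the class name RoundedConv2d is hypothetical, and it assumes the default padding_mode="zeros"):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RoundNoGradient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, g):
        # Identity gradient: behave as if forward were x -> x.
        return g


class RoundedConv2d(nn.Conv2d):
    """Hypothetical Conv2d that rounds its weights in the forward pass."""

    def forward(self, input):
        # Round a copy; self.weight itself stays a float nn.Parameter.
        weight = RoundNoGradient.apply(self.weight)
        # With padding_mode="zeros", F.conv2d using the layer's own
        # stride/padding/dilation/groups matches nn.Conv2d's behaviour.
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


conv = RoundedConv2d(16, 33, 3, stride=2)
x = torch.randn(20, 16, 50, 100)
y = conv(x)
y.sum().backward()
# conv.weight.grad now holds the gradient w.r.t. the rounded weights,
# passed through to the float weights unchanged.
```

Because the float weights are kept and only a rounded copy enters the convolution, the optimizer keeps updating full-precision values while the forward pass always sees integers.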

Hi ezyang,
Thanks for your reply! But running your sample code raises "AttributeError: type object 'RoundNoGradient' has no attribute 'apply'". Do you have any idea why?

Also, is there any way to treat the gradient as “no connection”?

That’s probably because the syntax I’ve written above is for PyTorch HEAD. If you can’t upgrade, try RoundNoGradient()(input) instead.

What do you mean by “no connection”?

Thanks, it works now. By “no connection” I mean manually setting the gradient to None, which is mathematically equivalent to a zero gradient but faster, since no large zero matrices are ever constructed. For example, in TensorFlow I can do this:

def __RoundGrad(_, grad):
    return None
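A rough PyTorch analogue is a custom autograd.Function whose backward returns None for its input. Autograd treats a None gradient as "nothing flows along this path" and can skip it rather than building a zero tensor. A sketch, with the hypothetical name BlockRoundGradient; when no custom Function is needed, x.round().detach() cuts the connection in the same way:

```python
import torch


class BlockRoundGradient(torch.autograd.Function):
    # Hypothetical name. Forward rounds; backward reports "no gradient".
    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, g):
        # Returning None means this input receives no gradient
        # from this path.
        return None


a = torch.randn(4, requires_grad=True)
BlockRoundGradient.apply(a).sum().backward()
print(a.grad)  # None: the only path to `a` was blocked

b = torch.randn(4, requires_grad=True)
(BlockRoundGradient.apply(b) + 2 * b).sum().backward()
print(b.grad)  # only the 2 * b path contributes, so every entry is 2
```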