Cannot override torch.round() after upgrading to the latest PyTorch version

I tried to override the torch.round() function, and it worked well before I upgraded my PyTorch version to 0.2.0. I don't know what's going wrong and need some help!

###-------------------------------------------------------------

import torch

class RoundNoGradient(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x.round()
    @staticmethod
    def backward(g):
        return g 


m = torch.nn.Conv2d(16, 10, 3, stride=2)
l = torch.nn.L1Loss()
input = torch.autograd.Variable(torch.randn(20, 16, 50, 10), requires_grad=True)
x = RoundNoGradient()(input)
y = m(x)
output = l(x, torch.autograd.Variable(torch.randn(20, 16, 50, 10)))
output.backward()
print(x.grad)

###-------------------------------------------------------------------------------
And it reports:

  File "test.py", line 18, in <module>
    output.backward()
  File "/usr/local/lib/python2.7/dist-packages/torch/autograd/variable.py", line 156, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
  File "/usr/local/lib/python2.7/dist-packages/torch/autograd/__init__.py", line 98, in backward
    variables, grad_variables, retain_graph)
TypeError: forward() takes exactly 1 argument (2 given)

If the forward() method of your Function is decorated with @staticmethod, it needs to be a new-style autograd Function, which takes a ctx object as its first argument:

class RoundNoGradient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # ctx replaces self in the new-style API; nothing needs to be saved here
        return x.round()
    @staticmethod
    def backward(ctx, g):
        # straight-through estimator: pass the incoming gradient through unchanged
        return g
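The ctx object is also how the new-style API replaces self: if backward needed anything from forward, you would stash it with ctx.save_for_backward. Round's straight-through gradient doesn't need that, but for reference, a hypothetical variant that masks the gradient outside [-1, 1] might look roughly like this (a sketch; on 0.2.0 the stashed values come back as ctx.saved_variables, while later releases use ctx.saved_tensors):

class ClampedRoundNoGradient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # stash the input so backward can see it
        return x.round()
    @staticmethod
    def backward(ctx, g):
        x, = ctx.saved_variables   # ctx.saved_tensors on newer versions
        # hypothetical choice: only let the gradient through where |x| <= 1
        return g * x.abs().le(1).type_as(g)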

Then, to call the new-style function, use .apply instead of instantiating it:

input = torch.autograd.Variable(torch.randn(20, 16, 50, 10), requires_grad=True)
x = RoundNoGradient.apply(input)

However, you won’t be able to read x.grad, because x is an intermediate (non-leaf) Variable and autograd only populates .grad on leaves. You can access input.grad, though:

print(input.grad)
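If you do need to inspect the gradient that flows into x, one option is to register a hook on it before calling backward (a sketch; the grads dict is just a hypothetical holder for the captured value):

grads = {}

def save_grad(g):
    grads['x'] = g   # stash the gradient flowing into x

x = RoundNoGradient.apply(input)
x.register_hook(save_grad)   # hook fires during backward
output = l(x, torch.autograd.Variable(torch.randn(20, 16, 50, 10)))
output.backward()
print(grads['x'])   # gradient w.r.t. the intermediate x
print(input.grad)   # gradient w.r.t. the leaf input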

Thanks for your instructive answer!