I tried to override the torch.round() function, and it worked well until I upgraded my PyTorch version to 0.2.0. I don't know what's going wrong with it and need some help!
###-------------------------------------------------------------
import torch

class RoundNoGradient(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x.round()

    @staticmethod
    def backward(g):
        return g

m = torch.nn.Conv2d(16, 10, 3, stride=2)
l = torch.nn.L1Loss()
input = torch.autograd.Variable(torch.randn(20, 16, 50, 10), requires_grad=True)
x = RoundNoGradient()(input)
y = m(x)
output = l(x, torch.autograd.Variable(torch.randn(20, 16, 50, 10)))
output.backward()
print(x.grad)
###-------------------------------------------------------------------------------
And it reports:
Traceback (most recent call last):
  File "test.py", line 18, in <module>
    output.backward()
  File "/usr/local/lib/python2.7/dist-packages/torch/autograd/variable.py", line 156, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
  File "/usr/local/lib/python2.7/dist-packages/torch/autograd/__init__.py", line 98, in backward
    variables, grad_variables, retain_graph)
TypeError: forward() takes exactly 1 argument (2 given)
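The error seems to come from the new-style (static-method) autograd.Function API: autograd now passes a context object as the first argument to both forward and backward, so forward(x) receives two arguments (ctx and the input) while it declares only one. Static Functions are also meant to be called via .apply rather than by instantiating the class. A minimal sketch of the adjusted definition (standalone, without the Conv2d/L1Loss parts of the original snippet):

```python
import torch

class RoundNoGradient(torch.autograd.Function):
    # New-style Functions receive a context object (ctx) as the
    # first argument of both forward and backward.
    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: pass the incoming gradient
        # through unchanged, since round() has zero gradient a.e.
        return g

x = torch.autograd.Variable(torch.randn(4, 4), requires_grad=True)
# Call via .apply instead of RoundNoGradient()(x).
y = RoundNoGradient.apply(x)
y.sum().backward()
print(x.grad)  # all ones: the gradient passed straight through
```

Note that in the original snippet `output` is computed from `x` (not `y`), and `x` is a non-leaf Variable, so `print(x.grad)` would show `None` even once the forward signature is fixed; you may have meant `l(y, ...)` and `input.grad`.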