The previous code gives me TypeError: forward() missing 1 required positional argument: 'target'. This seems to happen because forward() takes ctx as its first argument. I tried changing ctx to self; that worked, but then I get nothing back when I call self.saved_variables or self.saved_tensors.
I think I must be making some rookie mistake here, but I don't know where.
Functions are used slightly differently from nn.Modules.
Your code sample should be:
import torch
from torch.autograd import Variable

a = Variable(torch.randn(4) + 10)
b = Variable(torch.randn(4) + 10)
# You never instantiate a Function yourself!
# Call its static apply() instead (GammaLoss is the custom Function from the previous code).
loss = GammaLoss.apply(a, b)
print(loss)
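Note that forward() must keep ctx as its first argument: apply() fills it in for you, and ctx.save_for_backward() inside forward() is what makes ctx.saved_tensors available later in backward(). For reference, a minimal sketch of the pattern (your actual GammaLoss isn't reproduced in this thread; the body below is just a mean-squared-error placeholder to illustrate it):

class GammaLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, out, target):
        # ctx is provided by apply(); saving here is what populates ctx.saved_tensors
        ctx.save_for_backward(out, target)
        return ((out - target) ** 2).mean()  # placeholder computation, not the real gamma loss

    @staticmethod
    def backward(ctx, grad_output):
        out, target = ctx.saved_tensors
        grad = 2 * (out - target) / out.numel()
        # one gradient per forward input; no gradient needed for target
        return grad_output * grad, None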
For convenience, you can wrap it in the following nn.Module so that it behaves like other loss functions and can be used the same way as in your original code sample:
import torch.nn as nn

class GammaLossMod(nn.Module):
    def __init__(self):
        super(GammaLossMod, self).__init__()

    def forward(self, out, target):
        # EDITED: argument order was wrong
        return GammaLoss.apply(out, target)
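A quick usage sketch (assuming the output Variable requires gradients and GammaLoss defines a backward(), as in the placeholder above):

criterion = GammaLossMod()
out = Variable(torch.randn(4) + 10, requires_grad=True)
target = Variable(torch.randn(4) + 10)
loss = criterion(out, target)
loss.backward()   # gradients flow back through GammaLoss.backward()
print(out.grad)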