Custom function missing argument

Hi, I’m quite new to PyTorch.
I’ve been trying to write a custom loss function. This is my first try:

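import torch
from torch.autograd import Function, Variable

class GammaLoss(Function):

    @staticmethod
    def forward(ctx, out, target):
        ctx.save_for_backward(out, target)
        # Full gamma loss, which needs a `shape` parameter that is not passed in yet:
        # return -(torch.log(target) * (shape - 1) - (target / out) * (shape - 1) - torch.log(out / target) * shape)
        ratio = torch.div(out, target)
        output = torch.log(ratio)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        out, target = ctx.saved_tensors
        # Gradient of log(out / target) w.r.t. out; the gamma loss above would use
        # -(1 / (out * out)) * ((shape - 1) * target - shape * out)
        dout = 1 / out
        # Element-wise product, and one gradient per forward input (none for target)
        return grad_output * dout, None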


a = Variable(torch.randn(4) + 10)
b = Variable(torch.randn(4) + 10)
loss = GammaLoss()
print(loss(a, b))

The code above gives me TypeError: forward() missing 1 required positional argument: 'target'. It seems that a is being taken as ctx. I tried renaming ctx to self; that made the error go away, but then self.saved_variables and self.saved_tensors both came back empty.
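The error message can be reproduced without PyTorch. This sketch (the class name Demo is hypothetical) calls a static forward(ctx, out, target) with only two arguments, which is what the message suggests happened: the first argument is bound to ctx, the second to out, and target is left unfilled.

```python
class Demo:
    # Same signature as GammaLoss.forward; nothing injects a value for ctx here
    @staticmethod
    def forward(ctx, out, target):
        return ctx, out, target

try:
    # Two arguments only: "a" fills ctx, "b" fills out, target is missing
    Demo.forward("a", "b")
except TypeError as err:
    message = str(err)

# Supplying a context first binds the inputs as intended
result = Demo.forward("ctx", "a", "b")
print(message)
print(result)
```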

I think I must be making some rookie mistake here, but I don’t know where.

Hi,

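Functions are used differently from nn.Modules: you never create an instance yourself. Instead you call the static apply method, which creates the context object and passes it to forward as ctx, followed by your own arguments. Calling an instance does not supply ctx, which is why a ended up in ctx and target was missing.
Your code sample should be: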

a = Variable(torch.randn(4) + 10)
b = Variable(torch.randn(4) + 10)
# You never instantiate a `Function` yourself !
loss = GammaLoss.apply(a, b)
print(loss)
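Since backward is written by hand, it is worth checking it against finite differences with torch.autograd.gradcheck. A minimal sketch, assuming a recent PyTorch (plain tensors with requires_grad instead of Variable) and the log-ratio forward from the question, whose derivative with respect to out is 1 / out. gradcheck needs double-precision inputs.

```python
import torch
from torch.autograd import Function, gradcheck

class GammaLoss(Function):
    @staticmethod
    def forward(ctx, out, target):
        # Keep the inputs for use in backward
        ctx.save_for_backward(out, target)
        return torch.log(out / target)

    @staticmethod
    def backward(ctx, grad_output):
        out, target = ctx.saved_tensors
        # d/d(out) log(out / target) = 1 / out, applied element-wise;
        # no gradient is returned for target
        return grad_output / out, None

# Double precision and values around 10 keep the finite differences accurate
a = (torch.randn(4, dtype=torch.double) + 10).requires_grad_()
b = torch.randn(4, dtype=torch.double) + 10

# apply() creates ctx and passes it to forward as the first argument
loss = GammaLoss.apply(a, b)
# gradcheck only checks inputs that require grad (here: a)
ok = gradcheck(GammaLoss.apply, (a, b))
print(loss, ok)
```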

For convenience, you can wrap it in the following nn.Module, so that it behaves like the other loss functions and can be used the same way as in your code sample:

import torch.nn as nn

class GammaLossMod(nn.Module):
    def __init__(self):
        # EDITED: argument order was wrong
        super(GammaLossMod, self).__init__()

    def forward(self, out, target):
        return GammaLoss.apply(out, target)
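Used through the wrapper, the custom backward is called by loss.backward(). A sketch, assuming a recent PyTorch and reusing the question's log-ratio forward. Because forward returns one value per element, the result is averaged with .mean() here (a choice made for this example) to get the scalar that backward() needs.

```python
import torch
import torch.nn as nn
from torch.autograd import Function

class GammaLoss(Function):
    @staticmethod
    def forward(ctx, out, target):
        ctx.save_for_backward(out, target)
        return torch.log(out / target)

    @staticmethod
    def backward(ctx, grad_output):
        out, target = ctx.saved_tensors
        # Element-wise gradient w.r.t. out; none for target
        return grad_output / out, None

class GammaLossMod(nn.Module):
    def __init__(self):
        super().__init__()  # Python 3 form, no arguments to get wrong

    def forward(self, out, target):
        return GammaLoss.apply(out, target)

a = (torch.randn(4) + 10).requires_grad_()
b = torch.randn(4) + 10

criterion = GammaLossMod()
loss = criterion(a, b).mean()  # reduce to a scalar before backward()
loss.backward()
print(a.grad)  # should match 1 / (4 * a), since mean() scales each grad by 1/4
```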

Thank you for the quick response, it works now! But I can’t get the second one to work. When I run

a = Variable(torch.randn(4) + 10)
b = Variable(torch.randn(4) + 10)
loss = GammaLossMod()
print(loss(a, b))

It raises TypeError: super() argument 1 must be type, not GammaLossMod, which is strange since I’m using Python 3.5.2.

Oh, my bad, it’s a typo: self and GammaLossMod are in the wrong order in my code sample. I’ll fix it.
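The error comes from Python itself rather than PyTorch: the two-argument form is super(Class, instance), so swapping them passes an instance where a type is expected. A small sketch with hypothetical plain classes (Base, Good, Swapped, ZeroArg) showing the wrong order, the right order, and the zero-argument super() that Python 3 also accepts:

```python
class Base:
    def __init__(self):
        self.initialised = True

class Good(Base):
    def __init__(self):
        super(Good, self).__init__()  # class first, then instance

class Swapped(Base):
    def __init__(self):
        super(self, Swapped).__init__()  # wrong order: instance first

class ZeroArg(Base):
    def __init__(self):
        super().__init__()  # Python 3 fills in the class and instance

try:
    Swapped()
except TypeError as err:
    message = str(err)

print(message)
print(Good().initialised, ZeroArg().initialised)
```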


Oh, that’s true, I hadn’t noticed either.