Problem with autograd.Function

I had a custom autograd.Function that worked with PyTorch 1.1.0.
A newer PyTorch (1.10.0) now says the Function's forward and backward must be static methods.
I've read the tutorial but still don't know how to make the Function maintain a variable (here, self.V) across the forward and backward passes. Could anyone help?

from torch import autograd
import torch.nn.functional as F

# old-style Function (worked in 1.1.0) that keeps state on the instance
class Exclusive(autograd.Function):
    def __init__(self, V):
        super(Exclusive, self).__init__()
        self.V = V

    def forward(self, inputs, targets):
        self.save_for_backward(inputs, targets)
        outputs = inputs.mm(self.V.t())
        return outputs

    def backward(self, grad_outputs):
        inputs, targets = self.saved_tensors
        grad_inputs = grad_outputs.mm(self.V) if self.needs_input_grad[0] else None
        for x, y in zip(inputs, targets):
            self.V[y] = F.normalize((self.V[y] + x) / 2, p=2, dim=0)
        return grad_inputs, None

Hi!

As you’ve recognized, because pytorch now requires the forward() and
backward() functions of an autograd.Function to be static, you can no
longer store any state (your self.V) in an instance of your Exclusive.

For this reason, pytorch has added a context argument (ctx) to forward()
and backward() that is used to save information from the forward() call
for use in the backward() call.

So you could do something like:

    @staticmethod
    def forward (ctx, V, inputs, targets):
        ctx.save_for_backward (V, inputs, targets)
        ...
    
    @staticmethod
    def backward (ctx, grad_outputs):
        V, inputs, targets = ctx.saved_tensors
        ...

You would have to store V somewhere else in your model, and then pass it
in when you call, e.g., Exclusive.apply (my_model.V, inputs, targets).
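
For example (just a sketch, with illustrative names like ExclusiveLayer, num_classes, and feat_dim that aren't from your code), you could register V as a buffer on a module and pass it through to the Function:

import torch
import torch.nn as nn

class ExclusiveLayer(nn.Module):
    def __init__(self, num_classes, feat_dim):
        super().__init__()
        # V is updated by hand inside backward(), not by the optimizer,
        # so a buffer (rather than an nn.Parameter) is a natural place for it
        self.register_buffer('V', torch.zeros(num_classes, feat_dim))

    def forward(self, inputs, targets):
        return Exclusive.apply(self.V, inputs, targets)

A buffer moves with .to() / .cuda() and is saved in the state_dict, but the optimizer won't try to update it.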

Best.

K. Frank

Dear Frank,

Thank you so much for your kind help!
I think that besides using ctx, I also need backward() to return one gradient per input of forward() (V, inputs, targets, in that order), so the gradient for inputs goes in the second slot:

        return None, grad_inputs, None

Below is the full code:

from torch import autograd
import torch.nn.functional as F


class Exclusive(autograd.Function):

    @staticmethod
    def forward(ctx, V, inputs, targets):
        ctx.save_for_backward(V, inputs, targets)
        outputs = inputs.mm(V.t())
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        V, inputs, targets = ctx.saved_tensors
        # inputs is the second argument of forward(), hence index 1
        grad_inputs = grad_outputs.mm(V) if ctx.needs_input_grad[1] else None
        # update the memory bank V in place
        for x, y in zip(inputs, targets):
            V[y] = F.normalize((V[y] + x) / 2, p=2, dim=0)
        # one gradient per input of forward(): V, inputs, targets
        return None, grad_inputs, None
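
For anyone reading along, here is a quick sanity check of the shapes (all sizes below are made up for illustration):

import torch
import torch.nn.functional as F

num_classes, feat_dim, batch = 5, 8, 4           # illustrative sizes
V = F.normalize(torch.randn(num_classes, feat_dim), p=2, dim=1)
inputs = torch.randn(batch, feat_dim, requires_grad=True)
targets = torch.randint(0, num_classes, (batch,))

outputs = Exclusive.apply(V, inputs, targets)    # shape (batch, num_classes)
outputs.sum().backward()
print(inputs.grad.shape)                         # torch.Size([4, 8])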

Everything is OK now!
Thank you!