Inherit from autograd.Function

I’m implementing a reverse gradient layer and I ran into this unexpected behavior when I used the code below:

```python
import random
import torch
import torch.nn as nn
from torch.autograd import Variable

class ReverseGradient(torch.autograd.Function):

    def __init__(self):
        super(ReverseGradient, self).__init__()

    def forward(self, x):
        return x

    def backward(self, x):
        return -x

class ReversedLinear(nn.Module):
    
    def __init__(self):
        super(ReversedLinear, self).__init__()
        self.linear = nn.Linear(1, 1)
        self.reverse = ReverseGradient()

    def __call__(self, x):
        return self.reverse(self.linear(x)).sum(0).squeeze(0)
        # return self.linear(x).sum(0).squeeze(0)

input1 = torch.rand(1, 1)
input2 = torch.rand(2, 1)

rl = ReversedLinear()

input1 = Variable(input1)
output1 = rl(input1)
input2 = Variable(input2)
output1 += rl(input2)
output1.backward()
```

I get a size-mismatch error when I run the code above. However, creating a new instance of ReverseGradient() on every forward pass seems to solve the problem. I'd just like to understand better how the autograd.Function class works.

This is not directly relevant to the issue you’re seeing, but it’s important to note:
Certain parts of torch.autograd currently assume, or will in the future assume, that the gradients returned by Functions are correct (i.e., equal to the mathematical derivative). If you want to do things that violate this assumption, that’s fine, but they should be implemented as gradient hooks (var.register_hook), which can arbitrarily modify gradient values, not as Functions.
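A minimal sketch of the hook-based approach, assuming a recent PyTorch release (so no `Variable` wrapper) and keeping the question's `nn.Linear(1, 1)` setup. The hook is registered on the linear layer's output inside `forward`, so every call gets its own hook and reusing the same module on two inputs does not share any state between the calls:

```python
import torch
import torch.nn as nn

class ReversedLinear(nn.Module):
    """Linear layer whose output gradient is negated during backward (sketch)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        out = self.linear(x)
        if out.requires_grad:
            # The hook receives the gradient w.r.t. `out`; its return value
            # replaces that gradient for the rest of the backward pass.
            out.register_hook(lambda grad: -grad)
        return out.sum()

rl = ReversedLinear()
input1 = torch.rand(1, 1)
input2 = torch.rand(2, 1)

# Same module applied twice, as in the question; no size mismatch.
output = rl(input1) + rl(input2)
output.backward()

# Without the hook, bias.grad would be +3 (one per input row); with it, -3.
print(rl.linear.bias.grad)
```

Because `forward` is overridden (rather than `__call__`), the module also keeps working with the usual `nn.Module` machinery.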


Yes, you should create new instances of functions each time you use them. But as James suggested, it might make more sense to use gradient hooks in your case.
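For reference, recent PyTorch releases no longer accept the instance-based `Function` style used in the question. Instead, `forward` and `backward` are static methods that take a `ctx` argument, and the function is invoked with `ReverseGradient.apply(x)`. Each `apply` call gets its own fresh `ctx`, which plays the role of the "new instance per use" described above. My understanding is that a legacy instance kept per-call state on itself, so applying it to the 2×1 input after the 1×1 one clobbered that state; treat that explanation as an assumption. A minimal sketch, keeping the gradient-reversal semantics (with the caveat above that this deliberately returns a non-mathematical gradient):

```python
import torch

class ReverseGradient(torch.autograd.Function):
    """Identity in forward, negated gradient in backward (sketch)."""

    @staticmethod
    def forward(ctx, x):
        # Return a view rather than `x` itself so the output is a distinct tensor.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

linear = torch.nn.Linear(1, 1)
input1 = torch.rand(1, 1)
input2 = torch.rand(2, 1)

# The same Function class is applied twice; each .apply call builds its own node.
output = (ReverseGradient.apply(linear(input1)).sum()
          + ReverseGradient.apply(linear(input2)).sum())
output.backward()
print(linear.bias.grad)
```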
