Using autograd.Function to compute gradients for a function

I am learning autograd and have been following the Extending PyTorch tutorial.

I would like to compute gradients with respect to x and v for the following function.

import torch
from torch.autograd import Function, gradcheck

class Func(Function):
    
    @staticmethod
    def forward(x, v):
        a = torch.tanh(x)
        o = torch.matmul(a, v)
        return o, a
    
    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, v = inputs
        o, a = outputs
        ctx.save_for_backward(x, v, a)
        
    @staticmethod
    def backward(ctx, grad_o, grad_a):
        x, v, a = ctx.saved_tensors

        dL_da = torch.matmul(grad_o, v.transpose(1, 0))
        dL_dv = torch.matmul(a.transpose(1, 0), grad_o)
        
        dL_dx = (dL_da / ((torch.cosh(x)) ** 2))
        return dL_dx, dL_dv

When I run gradcheck as

func = Func.apply
x = torch.randn(2, 3, requires_grad=True, dtype=torch.double)
y = torch.randn(3, 2, requires_grad=True, dtype=torch.double)
test = gradcheck(func, (x, y), eps=1e-6, atol=1e-4)

it fails with the following error:

GradcheckError: Jacobian mismatch for output 1 with respect to input 0,
numerical:tensor([[0.4000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.7561, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.9592, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.9947, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.9674, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4168]], dtype=torch.float64)
analytical:tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], dtype=torch.float64)

I am not sure what I am doing wrong here. Can someone help me figure it out?

Other notes

I tried out the simple functions, tanh and matmul, on their own, and they seem to work fine:

class Mm(Function):
    
    @staticmethod
    def forward(x, y):
        return torch.matmul(x, y)
    
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, y = inputs
        ctx.save_for_backward(x, y)
        
    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        dx = torch.matmul(grad_output, y.transpose(1, 0))
        dy = torch.matmul(x.transpose(1, 0), grad_output)
        return dx, dy

class Tanh(Function):
    
    @staticmethod
    def forward(x):
        return torch.tanh(x)
    
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)
        
    @staticmethod
    def backward(ctx, grad_y):
        x, = ctx.saved_tensors
        return grad_y * (1 / ((torch.cosh(x)) ** 2))
    
mm = Mm.apply
x = torch.randn(2, 3, requires_grad=True, dtype=torch.double)
y = torch.randn(3, 2, requires_grad=True, dtype=torch.double)
test = gradcheck(mm, (x, y), eps=1e-6, atol=1e-4)

tanh = Tanh.apply
x = torch.randn(1, 3, requires_grad=True, dtype=torch.double)
test = gradcheck(tanh, x, eps=1e-6, atol=1e-4)
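
Chaining the two also works: gradcheck passes when autograd composes their backwards (a quick sketch reusing mm and tanh from above):

def composed(x, v):
    # autograd chains Tanh.backward and Mm.backward automatically
    return mm(tanh(x), v)

x = torch.randn(2, 3, requires_grad=True, dtype=torch.double)
y = torch.randn(3, 2, requires_grad=True, dtype=torch.double)
test = gradcheck(composed, (x, y), eps=1e-6, atol=1e-4)  # passes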

Shouldn’t the .apply method take arguments? That’s probably why you get a zero tensor for your analytical results. You should probably define func as:

def func(x, v):
    return Func.apply(x, v)

The docs say we can do it either way:

# Option 1: alias
linear = LinearFunction.apply

# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
    return LinearFunction.apply(input, weight, bias)

Fair enough.

I’ve just realized you haven’t used grad_a in your code - could that be why you get a zero tensor?


That’s another question I have: since I can compute dL_da from grad_o, should I use the computed value or the grad_a that is passed in for that output? If I am not wrong, dL_da should be equal to grad_a.
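
To make this concrete, here is a quick check with built-in autograd (a minimal sketch, independent of my Function): when both outputs feed into the loss, dL/dx picks up one contribution through o and a separate one directly through a, so the two are not interchangeable.

import torch

x = torch.randn(2, 3, requires_grad=True, dtype=torch.double)
v = torch.randn(3, 2, requires_grad=True, dtype=torch.double)

a = torch.tanh(x)
o = torch.matmul(a, v)

# let both outputs contribute to the loss
(o.sum() + a.sum()).backward()

# total gradient = (direct path through a) + (path through o),
# both scaled by the tanh derivative 1 - tanh(x)**2
grad_o = torch.ones(2, 2, dtype=torch.double)
grad_a = torch.ones(2, 3, dtype=torch.double)
expected = (grad_a + torch.matmul(grad_o, v.transpose(1, 0))) * (1 - torch.tanh(x) ** 2)
print(torch.allclose(x.grad, expected))  # True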

I tried the following variant using grad_a, but I still get the same error.

import torch
from torch.autograd import gradcheck
from torch.autograd import Function

class Func(Function):
    
    @staticmethod
    def forward(x, v):
        a = torch.tanh(x)
        o = torch.matmul(a, v)
        return o, a

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, v = inputs
        o, a = outputs
        ctx.save_for_backward(x, v, a)
        
    @staticmethod
    def backward(ctx, grad_o, grad_a):
        x, v, a = ctx.saved_tensors

        dL_dv = torch.matmul(a.transpose(1, 0), grad_o)
        
        dL_dx = (grad_a / ((torch.cosh(x)) ** 2))
        return dL_dx, dL_dv


func = Func.apply
x = torch.randn(2, 3, requires_grad=True, dtype=torch.double)
y = torch.randn(3, 2, requires_grad=True, dtype=torch.double)
test = gradcheck(func, (x, y), eps=1e-6, atol=1e-4)

Are you sure your backward formula is correct? There should be a (1 - tanh(x)**2) term, as that’s the derivative of tanh(x) w.r.t. x.

The Matrix Cookbook should be able to help with the linear algebra.


    def backward(ctx, grad_o, grad_a):
        x, v, a = ctx.saved_tensors

        dL_da = grad_a + torch.matmul(grad_o, v.transpose(1, 0))
        dL_dv = torch.matmul(a.transpose(1, 0), grad_o)

        dL_dx = dL_da * (1 / (torch.cosh(x) ** 2))
        return dL_dx, dL_dv
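
Note that the two factors are the same quantity - 1 / (torch.cosh(x) ** 2) is exactly the (1 - tanh(x)**2) term:

$$\frac{d}{dx}\tanh(x) = 1 - \tanh^2(x) = \frac{1}{\cosh^2(x)}$$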

AlphaBetaGamma has the right idea: you missed the grad_a term. Technically, though, your backward isn’t wrong if you never use a in any gradient computation. I would guess a is only returned here so you can save it for backward (usually you would then have a wrapper function that filters out a from the outputs).
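
A minimal sketch of such a wrapper (the name func here is just illustrative): it exposes only o to callers, while a is still saved for backward inside the Function.

def func(x, v):
    # a is an internal value; callers (and gradcheck) only ever see o
    o, _a = Func.apply(x, v)
    return o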

The formula is detected as wrong by gradcheck because it tries to backprop from a back to the inputs, and since a is indeed a differentiable function of the inputs, the numerical and analytical gradients differ.
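
For completeness, here is a sketch with the corrected backward folded back into the Function (forward and setup_context unchanged from the question); this version should pass the gradcheck call above:

import torch
from torch.autograd import Function, gradcheck

class Func(Function):

    @staticmethod
    def forward(x, v):
        a = torch.tanh(x)
        o = torch.matmul(a, v)
        return o, a

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, v = inputs
        o, a = outputs
        ctx.save_for_backward(x, v, a)

    @staticmethod
    def backward(ctx, grad_o, grad_a):
        x, v, a = ctx.saved_tensors

        # sum both gradient paths into a before applying the tanh derivative
        dL_da = grad_a + torch.matmul(grad_o, v.transpose(1, 0))
        dL_dv = torch.matmul(a.transpose(1, 0), grad_o)

        dL_dx = dL_da / (torch.cosh(x) ** 2)
        return dL_dx, dL_dv

x = torch.randn(2, 3, requires_grad=True, dtype=torch.double)
y = torch.randn(3, 2, requires_grad=True, dtype=torch.double)
test = gradcheck(Func.apply, (x, y), eps=1e-6, atol=1e-4)  # True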
