No grad accumulator error for custom backward

I am trying to write my own backward function to calculate gradients. However, I get the following error: “RuntimeError: No grad accumulator for a saved leaf!”. This error can be reproduced using the code below. Any help is greatly appreciated! Thank you.

import torch
from torch.autograd import gradcheck

class CustomAutoG(torch.autograd.Function):
    #REF: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
    #REF: https://pytorch.org/docs/1.7.1/notes/extending.html#extending-torch-autograd
    @staticmethod
    def forward(ctx, x1, x2, y, z):
        x1 = x1.view(1, -1)
        x2 = x2.view(1, -1)
        z = z.view(z.size(0), z.size(1), 1)

        x1_floor = torch.floor(x1)
        x1_ceil = torch.ceil(x1)

        out_floor = torch.zeros_like(x1).repeat(z.size(1), 1)
        out_ceil = torch.zeros_like(out_floor)
        for i,ysmpl in enumerate(y):
            mask = (x1_floor == ysmpl).double()
            out_floor.add_(z[i]*mask)
            mask = (x1_ceil == ysmpl).double()
            out_ceil.add_(z[i]*mask)

        out = out_floor*(x1_ceil-x1) + out_ceil*(x1-x1_floor)
        # Save Tensors for backward; note that all of these are created inside
        # forward, so none of them is an input or output of the Function.
        ctx.save_for_backward(out_floor, out_ceil, out, x2)
        return x2*out

    @staticmethod
    def backward(ctx, grad_output):
        out_floor, out_ceil, out, x2 = ctx.saved_tensors
        grad_x1 = torch.sum(grad_output*x2*(out_ceil-out_floor), dim=0)
        grad_x2 = torch.sum(grad_output*out, dim=0)
        return grad_x1, grad_x2, None, None

x1 = torch.nn.Parameter(torch.DoubleTensor([20.1, 19.3]), requires_grad=True)
x2 = torch.nn.Parameter(torch.DoubleTensor([0.001, 0.002]), requires_grad=True)
y = torch.DoubleTensor([18,19,20,21])
z = torch.DoubleTensor([[0.1,0.2],[0.3,0.4],[0.5,0.6],[0.7,0.8]])

input = (x1,x2,y,z)
test = gradcheck(CustomAutoG.apply, input, eps=1e-6, atol=1e-4)
print('gradcheck:', test)

Hi,

The problem is that you’re using ctx.save_for_backward() for things that are neither inputs nor outputs.
You can just save these directly on the ctx and it will work:

import torch
from torch.autograd import gradcheck

class CustomAutoG(torch.autograd.Function):
    #REF: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
    #REF: https://pytorch.org/docs/1.7.1/notes/extending.html#extending-torch-autograd
    @staticmethod
    def forward(ctx, x1, x2, y, z):
        x1 = x1.view(1, -1)
        x2 = x2.view(1, -1)
        z = z.view(z.size(0), z.size(1), 1)

        x1_floor = torch.floor(x1)
        x1_ceil = torch.ceil(x1)

        out_floor = torch.zeros_like(x1).repeat(z.size(1), 1)
        out_ceil = torch.zeros_like(out_floor)
        for i,ysmpl in enumerate(y):
            mask = (x1_floor == ysmpl).double()
            out_floor.add_(z[i]*mask)
            mask = (x1_ceil == ysmpl).double()
            out_ceil.add_(z[i]*mask)

        out = out_floor*(x1_ceil-x1) + out_ceil*(x1-x1_floor)
        # These are intermediates that are neither inputs nor outputs of the
        # Function, so save them directly on the ctx instead of save_for_backward()
        ctx.x2 = x2
        ctx.out_floor = out_floor
        ctx.out_ceil = out_ceil
        ctx.out = out
        return x2*out

    @staticmethod
    def backward(ctx, grad_output):
        grad_x1 = torch.sum(grad_output*ctx.x2*(ctx.out_ceil-ctx.out_floor), dim=0)
        grad_x2 = torch.sum(grad_output*ctx.out, dim=0)
        return grad_x1, grad_x2, None, None

x1 = torch.nn.Parameter(torch.DoubleTensor([20.1, 19.3]), requires_grad=True)
x2 = torch.nn.Parameter(torch.DoubleTensor([0.001, 0.002]), requires_grad=True)
y = torch.DoubleTensor([18,19,20,21])
z = torch.DoubleTensor([[0.1,0.2],[0.3,0.4],[0.5,0.6],[0.7,0.8]])

input = (x1,x2,y,z)
test = gradcheck(CustomAutoG.apply, input, eps=1e-6, atol=1e-4)
print('gradcheck:', test)

This works! Thank you.

Although this solution works, the documentation says that we must use ctx.save_for_backward() to save tensors for the backward computation. Is there any guidance on when and why we should save a tensor directly on ctx versus using ctx.save_for_backward()? I don’t clearly see the need for ctx.save_for_backward() if we can simply save tensors directly on ctx.

Hi,

You should use save_for_backward() only for input/output Tensors. All other (intermediate) Tensors can be saved directly on the ctx.
Roughly speaking, save_for_backward() exists for Tensors that are part of the autograd graph: it avoids reference cycles when saving outputs and lets autograd check that saved Tensors were not modified in-place before backward runs. A Tensor created inside forward() that is neither an input nor an output is not tracked that way, so stashing it on the ctx is fine (and, as you saw, passing it to save_for_backward() is what raises the error here).
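
As a minimal sketch of the pattern (a hypothetical ScaleByExp function, not taken from the code above): the input x goes through save_for_backward(), while an intermediate computed inside forward is stashed as a plain attribute on the ctx.

import torch

class ScaleByExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        scale = torch.exp(-x.abs())   # intermediate computed inside forward
        ctx.save_for_backward(x)      # x is an input -> use save_for_backward
        ctx.scale = scale             # intermediate -> plain ctx attribute
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # d/dx [x * exp(-|x|)] = exp(-|x|) * (1 - |x|)
        return grad_output * ctx.scale * (1 - x.abs())

x = torch.randn(5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(ScaleByExp.apply, (x,), eps=1e-6, atol=1e-4))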
