Variable grad is always None when extending autograd

Hi,

I have been extending autograd following the instructions. However, after I call the backward function, the gradients of the Variables are always None, even though the values returned by my extended backward function are not None at all.

The extended op is as follows:

import torch
import _ext.cunnex
import pdb
from torch.autograd import Variable

class SeparableConvolution(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, vertical, horizontal):
        ctx.save_for_backward(input, vertical, horizontal)
        ctx.data_for_backward = input, vertical, horizontal
        intBatches = input.size(0)
        intInputDepth = input.size(1)
        intInputHeight = input.size(2)
        intInputWidth = input.size(3)
        intFilterSize = min(vertical.size(1), horizontal.size(1))
        intOutputHeight = min(vertical.size(2), horizontal.size(2))
        intOutputWidth = min(vertical.size(3), horizontal.size(3))

        assert(intInputHeight - 51 == intOutputHeight - 1)
        assert(intInputWidth - 51 == intOutputWidth - 1)
        assert(intFilterSize == 51)

        assert(input.is_contiguous())
        assert(vertical.is_contiguous())
        assert(horizontal.is_contiguous())

        output = input.new().resize_(intBatches, intInputDepth, intOutputHeight, intOutputWidth).zero_()

        if input.is_cuda:
            _ext.cunnex.SeparableConvolution_cuda_forward(
                input,
                vertical,
                horizontal,
                output
            )

        else:
            raise NotImplementedError() # CPU VERSION NOT IMPLEMENTED

        # end
        return output
    # end
    @staticmethod
    def backward(ctx, grad_output):
        input, vertical, horizontal = ctx.data_for_backward
        intBatches = input.size(0)
        intInputDepth = input.size(1)
        intInputHeight = input.size(2)
        intInputWidth = input.size(3)
        intFilterSize = min(vertical.size(1), horizontal.size(1))
        intOutputHeight = min(vertical.size(2), horizontal.size(2))
        intOutputWidth = min(vertical.size(3), horizontal.size(3))

        grad_input = input.new().resize_(intBatches, intInputDepth, intInputHeight, intInputWidth).zero_()
        grad_vertical = vertical.new().resize_(intBatches, intFilterSize, intOutputHeight, intOutputWidth).zero_()
        grad_horizontal = horizontal.new().resize_(intBatches, intFilterSize, intOutputHeight, intOutputWidth).zero_()

        if grad_output.is_cuda:
            _ext.cunnex.SeparableConvolution_cuda_backward(
                grad_output.data,
                input,
                vertical,
                horizontal,
                grad_input,
                grad_vertical,
                grad_horizontal
            )

        else:
            raise NotImplementedError() # CPU VERSION NOT IMPLEMENTED

        return Variable(grad_input), Variable(grad_vertical), Variable(grad_horizontal)

    # end
    # end

The detailed forward and backward methods, ‘SeparableConvolution_cuda_forward’ and ‘SeparableConvolution_cuda_backward’, are implemented in CUDA. I have checked that ‘Variable(grad_input)’, ‘Variable(grad_vertical)’ and ‘Variable(grad_horizontal)’ are valid torch tensors, not None. I call this new operation in a simple test case with a GPU available. However, when I print out the gradients of the inputs, they are always None.

The simple test case is:

import torch
from torch.autograd import Variable
from util.SeparableConvolution import SeparableConvolution
import math
import pdb

img = Variable(torch.zeros(1, 1, 51, 51), requires_grad=True).cuda()
img[0, 0, 25, 25] = 1
img[0, 0, 23, 24] = 5
v = Variable(torch.zeros(1, 51, 1, 1), requires_grad=True).cuda()
v[0, 25, 0, 0] = 2
v[0, 23, 0, 0] = 7
h = Variable(torch.zeros(1, 51, 1, 1), requires_grad=True).cuda()
h[0, 25, 0, 0] = 3
h[0, 24, 0, 0] = 11

input = img, v, h
output = SeparableConvolution.apply(*input)

gt = 389 * Variable(torch.ones(1, 1, 1, 1)).cuda()
loss_Lp = torch.nn.MSELoss()
loss = loss_Lp(output, gt)

loss.backward()
print(img.grad)
print(v.grad)
print(h.grad)

‘img.grad’, ‘v.grad’ and ‘h.grad’ are all None. I cannot find the error in my code. Could anyone help me with this issue?

Best

Does changing your code from

x = Variable(torch.zeros(...), requires_grad=True).cuda()

to

x = Variable(torch.zeros(...).cuda(), requires_grad=True)

help?

Yes, it works and now I can get gradients. Thank you so much.

Yeah, it looks like what’s happening is that x = Variable(torch.zeros(...), requires_grad=True).cuda() first creates an intermediate Variable y = Variable(torch.zeros(...), requires_grad=True) and then assigns x = y.cuda().

Since y is the leaf node, the gradients only accumulate in y and not in x.
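
For illustration, a minimal sketch of the leaf/non-leaf behaviour described above (the shapes are arbitrary):

import torch
from torch.autograd import Variable

# CPU leaf moved to the GPU afterwards: the gradient stays on the CPU leaf.
y = Variable(torch.zeros(3), requires_grad=True)  # y is the leaf
x = y.cuda()                                      # x is a non-leaf copy of y
x.sum().backward()
print(y.grad)  # populated
print(x.grad)  # None, because x is not a leaf

# Leaf created directly on the GPU: the gradient accumulates on x itself.
x = Variable(torch.zeros(3).cuda(), requires_grad=True)
x.sum().backward()
print(x.grad)  # populated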

Thank you so much for the explanation.

From what I understand, the difference between Variable(Tensor.cuda()) and Variable(Tensor).cuda() only comes into play when my training task requires the gradient with respect to the inputs. Otherwise, for training tasks like object recognition on ImageNet, these two declarations won’t make any difference, because the task doesn’t need the gradients with respect to the inputs (the images in this case). Is my understanding correct?

The reason I ask is that I have seen both of these declarations used in example code and repos.

If you don’t require grad, then yes, those two are equivalent.
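
A minimal sketch of that case, assuming the input itself needs no gradient (the shapes are arbitrary):

import torch
from torch.autograd import Variable

a = Variable(torch.zeros(4, 3, 224, 224).cuda())   # leaf created directly on the GPU
b = Variable(torch.zeros(4, 3, 224, 224)).cuda()   # CPU leaf, then a GPU copy
# Neither a nor b requires grad, so backward() never tries to populate a.grad
# or b.grad, and both behave identically as inputs to the forward pass.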

What would be an example of when you would need gradients with respect to your inputs?

One reason I have encountered is for visualization purposes; see pytorch-cnn-visualizations, more specifically where the gradient is captured in a module hook.
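
For example, a minimal saliency-style sketch; the model and shapes are arbitrary, and a plain hook on the input Variable is used here instead of the module hook from the linked repo:

import torch
import torch.nn as nn
from torch.autograd import Variable

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())

img = Variable(torch.randn(1, 3, 32, 32), requires_grad=True)

captured = []
img.register_hook(captured.append)  # receives d(score)/d(img) during backward

score = model(img).sum()            # any scalar score of interest
score.backward()

print(img.grad.size())       # gradient w.r.t. the input image (what you visualize)
print(captured[0].size())    # the same gradient, captured by the hook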

It’s a basic question: why do you call ‘y’ a leaf node?

Is it related to a tree structure inside ‘Variable()’?

Sorry for the basic question.
