Unable to backpropagate all the way back to the input

Hi,
I am trying to use autograd in the following example, but I am unable to get gradients all the way back to the theta variable.

import torch
from torch.autograd import Variable

theta = Variable(torch.randn(1)*3.1414)
theta.requires_grad_(True)
cosval = Variable(torch.cos(theta))
cosval.requires_grad_(True)
mat = Variable(torch.Tensor([1,cosval,cosval]))
mat.requires_grad_(True)
inpt = torch.Tensor([3,3,3])
y = torch.dot(mat,inpt)
y.backward()

When I run the above code, I can see the gradient of mat, but the gradients of cosval and theta are None.

Hi @harishkool,

Here the graph built by autograd will only represent "y = dot(mat, inpt)".

That's because mat = Variable(torch.Tensor([1,cosval,cosval])) is a leaf node (the construction torch.Tensor([1,cosval,cosval]) is not a differentiable operation, so autograd has no path from mat back to cosval or theta).
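You can see this directly on the tensor itself. Here is a minimal check (I write float(cosval) explicitly, which is effectively what the constructor does: it copies plain numbers out of cosval):

import torch

theta = torch.randn(1, requires_grad=True)
cosval = torch.cos(theta)

# The constructor copies raw values, so the result is a brand-new
# leaf tensor with no recorded history back to theta:
mat = torch.Tensor([1, float(cosval), float(cosval)])
print(mat.is_leaf)   # True
print(mat.grad_fn)   # None -- the graph stops here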

Also, Variable is deprecated and you should use Tensor instead.
If you replace the Variables with Tensors and use torch.cat to create mat:

mat = torch.cat((torch.Tensor([1]), cosval, cosval))

it should work.

You will also need to call retain_grad on the non-leaf Tensors whose gradients you want to inspect after the backward pass. So you should end up with something like:

import torch

theta = torch.randn(1, requires_grad=True) * 3.1414
theta.retain_grad()  # theta is non-leaf (result of the multiplication)

cosval = torch.cos(theta)
cosval.retain_grad()

# torch.cat is a differentiable op, so mat stays connected to theta:
mat = torch.cat((torch.Tensor([1]), cosval, cosval))
mat.retain_grad()

inpt = torch.Tensor([3, 3, 3])
y = torch.dot(mat, inpt)
y.backward()  # y is a scalar, so no gradient argument is needed
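As a quick sanity check (my addition, not part of the original answer): here y = 3 + 6*cos(theta), so after the backward pass the gradients should be:

print(mat.grad)     # dy/dmat    = inpt           -> tensor([3., 3., 3.])
print(cosval.grad)  # dy/dcosval = 3 + 3          -> tensor([6.])
print(theta.grad)   # dy/dtheta  = -6*sin(theta)
print(torch.allclose(theta.grad, -6 * torch.sin(theta)))  # True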

Thank you, your solution solved my problem. I am just wondering: how come torch.Tensor([1,cosval,cosval]) is not differentiable while torch.cat((torch.Tensor([1]),cosval,cosval)) is? Is it because torch.Tensor() creates a leaf node, whereas torch.cat() is a torch op and therefore becomes part of the graph?

No problem, happy to read that! 🙂
That's correct: torch.Tensor() is not an autograd op, so its output is a leaf node in your graph, whereas torch.cat is one and becomes part of the graph.
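For completeness, the same check as above on the torch.cat version shows it is recorded in the graph (a minimal sketch):

import torch

theta = torch.randn(1, requires_grad=True)
cosval = torch.cos(theta)

# torch.cat is a differentiable op, so autograd records a backward node:
mat = torch.cat((torch.Tensor([1]), cosval, cosval))
print(mat.is_leaf)   # False
print(mat.grad_fn)   # <CatBackward0 object at 0x...>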

Got it. Thank you, Serge!