Gradient w.r.t. input image is None

Hi everybody,

I am trying to compute the gradient of an output score w.r.t. the input image. Although I call requires_grad_() on the input, the gradient is None. Here are two equivalent snippets to get the gradients: Code1 raises the error RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior, and Code2 returns None. Does anybody have an idea what I missed?

Code1:
net_d.eval()
im = images[0, ...].reshape(1, images.shape[1], images.shape[2], images.shape[3]).clone().cuda()
im.requires_grad_()
d_x = net_d(im)
d_x = torch.max(torch.sigmoid(d_x))
d_x.requires_grad_()
grad_1 = torch.autograd.grad(d_x, im, create_graph=True)
print(grad_1)

Code2:
net_d.eval()
im = images[0, ...].reshape(1, images.shape[1], images.shape[2], images.shape[3]).clone().cuda()
im.requires_grad_()
d_x = net_d(im)
d_x = torch.max(torch.sigmoid(d_x))
out = Variable(d_x, requires_grad=True)
out.backward(retain_graph=True)
print(im.grad)

Your first approach works fine and you wouldn’t need to call .requires_grad_() on the output:

import torch
import torch.nn as nn

im = torch.randn(1, 1, device='cuda')
im.requires_grad_()
net_d = nn.Linear(1, 1).cuda()

d_x = net_d(im)
d_x = torch.max(torch.sigmoid(d_x))
grad_1 = torch.autograd.grad(d_x, im, create_graph=True)
print(grad_1)
# (tensor([[-0.0331]], device='cuda:0', grad_fn=<TBackward0>),)

assuming you are not detaching the computation graph.
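For comparison, the backward() variant (Code2) should also work once the score is used directly instead of being rewrapped in Variable. Here is a minimal CPU sketch, assuming a hypothetical nn.Linear as a stand-in for net_d (any differentiable module should behave the same):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net_d = nn.Linear(4, 1)  # hypothetical stand-in for the real discriminator
net_d.eval()

im = torch.randn(1, 4, requires_grad=True)  # leaf tensor we want the gradient for

# Scalar score that is still attached to the graph built from im
d_x = torch.max(torch.sigmoid(net_d(im)))
d_x.backward()  # no Variable(...) rewrap and no requires_grad_() on the output

print(im.grad.shape)  # torch.Size([1, 4])
```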

In case you get an error after removing d_x.requires_grad_(), check where the computation graph is detached in your model (e.g. by rewrapping an intermediate activation into a new tensor).
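To illustrate what "rewrapping" means, the sketch below mirrors Code2 on a toy tensor: wrapping a non-leaf result in Variable(..., requires_grad=True) produces a new leaf with no history, so backward() never reaches the original input. The tensors x and y are illustrative only:

```python
import torch
from torch.autograd import Variable

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()  # y has a grad_fn and is connected to x

# Rewrapping y creates a new leaf tensor with no history back to x
y_wrapped = Variable(y, requires_grad=True)
print(y.grad_fn is not None)  # True
print(y_wrapped.grad_fn)      # None: the wrapper is detached from x

y_wrapped.backward()
print(x.grad)  # None: the gradient stops at y_wrapped and never reaches x
```

The same effect would occur inside a model if an activation were rebuilt, for example via torch.tensor(...) or .detach(), before the rest of the forward pass.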

Thanks for the quick response @ptrblck!

I loaded the model state_dict one line before net_d.eval():
‘net_d.load_state_dict(torch.load(PATH))’
Removing d_x.requires_grad_() raises the error ‘element 0 of tensors does not require grad and does not have a grad_fn’. Should I calculate the saliency map in another way when the model weights are loaded from a saved state_dict?

No, loading the state_dict should work and the error points to a detached computation graph.
Could you post the model definition as well as the input shape you are using so that I could reproduce the issue, please?
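As a quick sanity check that loading a state_dict does not by itself detach anything, here is a minimal sketch with a small hypothetical CNN (not the Thoracic model) and an in-memory buffer standing in for PATH. The loaded parameters keep requires_grad=True and the gradient w.r.t. the input is computed as usual:

```python
import io
import torch
import torch.nn as nn

def make_net():
    # Hypothetical small model; the real net_d comes from the linked repository
    return nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(),
                         nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 2))

src = make_net()
buf = io.BytesIO()
torch.save(src.state_dict(), buf)  # stands in for saving a checkpoint to PATH
buf.seek(0)

net_d = make_net()
net_d.load_state_dict(torch.load(buf))
net_d.eval()

im = torch.randn(1, 3, 32, 32, requires_grad=True)
score = torch.max(torch.sigmoid(net_d(im)))
grad, = torch.autograd.grad(score, im)
print(score.requires_grad, grad.shape)  # True torch.Size([1, 3, 32, 32])
```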

Sure, the model is defined in src/model.py in the GitHub repository ‘GitHub - ricbl/eye-tracking-localization: This repository contains code for the paper "Localization supervision of chest x-ray classifiers using label-specific eye-tracking annotation".’. The input size is [bs x 3 x 512 x 512]. The function ‘get_highlight(image, self.opt.get_saliency)’ is written in a separate script from train.py. Please let me know if you need further information or the checkpoint.

import torch
import numpy as np
import torchvision.transforms as transforms
from torch.autograd import Variable
import opts
opt = opts.get_opt()
import model

def get_highlight(images, method):
    # PATH (the checkpoint path) is not defined in this snippet
    net_d = model.Thoracic(opt.grid_size, pretrained=opt.use_pretrained, calculate_cam=opt.calculate_cam, last_layer_index=opt.last_layer_index).cuda()
    net_d.load_state_dict(torch.load(PATH))
    net_d.eval()
    im = images[0, ...].reshape(1, images.shape[1], images.shape[2], images.shape[3]).clone().cuda()
    im.requires_grad_()
    d_x = net_d(im)
    d_x = torch.max(torch.sigmoid(d_x))
    out = Variable(d_x, requires_grad=True)
    out.backward(retain_graph=True)
    # max of the absolute gradient over the channel dimension
    saliency, _ = torch.max(im.grad.data.abs(), dim=1)
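For reference, a vanilla gradient saliency map along the lines of get_highlight could look like the sketch below. The helper name vanilla_saliency and the toy model are hypothetical; it assumes the model returns a plain tensor of logits (if Thoracic returns a tuple, the relevant element would have to be selected first). Note that the score is backpropagated directly, without a Variable rewrap:

```python
import torch
import torch.nn as nn

def vanilla_saliency(net: nn.Module, image: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: |d score / d pixel|, max over channels, shape [H, W].

    image: a single image of shape [C, H, W].
    """
    net.eval()
    # Same as images[0, ...].reshape(1, C, H, W): add a batch dimension, make a leaf
    im = image.detach().unsqueeze(0).clone().requires_grad_()
    score = torch.max(torch.sigmoid(net(im)))  # scalar score, as in the question
    score.backward()                           # populates im.grad
    saliency, _ = torch.max(im.grad.abs(), dim=1)  # [1, H, W]
    return saliency[0]

# Usage with a toy model standing in for net_d
torch.manual_seed(0)
net = nn.Sequential(nn.Conv2d(3, 2, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten())
img = torch.randn(3, 16, 16)
sal = vanilla_saliency(net, img)
print(sal.shape)  # torch.Size([16, 16])
```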

I’m unable to reproduce the issue using:

import torch
from model import Thoracic  # src/model.py from the linked repository

model = Thoracic().cuda()
bs = 2
x = torch.randn(bs, 3, 512, 512).cuda().requires_grad_()

out = model(x)
loss = out[0].mean() + out[1].mean()

grad_1 = torch.autograd.grad(loss, x, create_graph=True)
print(grad_1)

and get a valid gradient. Could you check what the difference between the codes could be?

Running your code, I get this error: ‘RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn’.

Did you disable the gradient calculation globally by chance? E.g. via torch.autograd.set_grad_enabled(False)?

I did, I’m sorry for not considering that. Thanks a lot for your help! :slight_smile:
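Since the root cause was a globally disabled grad mode, here is a minimal sketch (toy nn.Linear model, hypothetical variable names) that reproduces both symptoms from the question and shows one way to fix them: with grad disabled, no graph is recorded, so d_x.requires_grad_() (as in Code1) only creates an unrelated leaf. Re-enabling autograd locally with torch.enable_grad() around the saliency computation restores the gradient. torch.set_grad_enabled is the same function as torch.autograd.set_grad_enabled:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net_d = nn.Linear(4, 1)  # hypothetical stand-in for the real model
im = torch.randn(1, 4, requires_grad=True)

torch.set_grad_enabled(False)  # e.g. set once, globally, elsewhere in the script
d_x = torch.max(torch.sigmoid(net_d(im)))
recorded_graph = d_x.grad_fn is not None
print(d_x.requires_grad, recorded_graph)  # False False: no graph was recorded

# What Code1 did: d_x becomes a new leaf that is unrelated to im
d_x.requires_grad_()
try:
    torch.autograd.grad(d_x, im)
    code1_failed = False
except RuntimeError as err:
    code1_failed = True
    print("RuntimeError:", err)

# Fix: re-enable autograd only where the saliency map is computed
with torch.enable_grad():
    d_x = torch.max(torch.sigmoid(net_d(im)))
    grad, = torch.autograd.grad(d_x, im)
print(grad.shape)  # torch.Size([1, 4])

torch.set_grad_enabled(True)  # restore the global default
```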