Why is the gradient w.r.t. an input coming out different from the gradient w.r.t. a clone of that input?

I am trying feature visualization and noticed a curious phenomenon: when I compare the norm of the gradient w.r.t. an input with the norm of the gradient w.r.t. a clone of that input, I get different answers. I have not been able to figure this out. I am pasting the Google Colab code here:

Imports etc., initial code:

!pip install --no-cache-dir -I pillow
# http://pytorch.org/
from os.path import exists
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'

!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl torchvision
import torch
import torchvision
import numpy as np
from matplotlib import pyplot as plt
import time
import pdb
!git clone http://github.com/tumble-weed/images
import os
os.listdir('images')
from skimage import io
from PIL import Image

The meat of the code:

im = io.imread('images/ILSVRC2012_val_00000013.JPEG')
model = torchvision.models.vgg16(pretrained=True)
model.eval()

s = 224
mean = [0.5,0.5,0.5]
std = [0.225,0.225,0.225]
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((s,s)),
                                          torchvision.transforms.ToTensor(),
                                          torchvision.transforms.Normalize(mean=mean,std=std)])
ref = transform(Image.fromarray(im))

class fwdHook():
  def __init__(self):
    self.feat = None
    # Closure over self: stores the layer's output on every forward pass.
    def hook(obj, input, output):
      self.feat = output
    self.hook = hook
  
model_layers = [l for n in model.children() for l in n]
hooks = [fwdHook() for l in model_layers]
hooked_layers = [l.register_forward_hook(h.hook) for l,h in zip(model_layers,hooks)]
if False: print(hooks[0].feat)
model(ref.unsqueeze(0))
if False: print(hooks[0].feat)
  
mag = 10
x_ = mag*np.random.randn(*ref.unsqueeze(0).shape).astype(np.float32)
x = torch.from_numpy(x_)
x = torch.autograd.Variable(x,requires_grad = True)
lidx = 3
model(ref.unsqueeze(0))
ref_feat = hooks[lidx].feat
model(x)
x_feat = hooks[lidx].feat



get_dist_from_ref = lambda feat:torch.sum((ref_feat - feat)**2)/torch.sum((ref_feat)**2)
loss = get_dist_from_ref(x_feat)
print(loss)

# loss.zero_grad()
loss.backward(retain_graph=True)
x_grad = x.grad.clone()
print(torch.norm(x_grad))
im_x_grad = x_grad.permute(0,2,3,1)[0]
im_x_grad = im_x_grad.detach().cpu().data.numpy()

# print(x_grad)

'''----------------------'''
xx = x.clone()
xx = torch.autograd.Variable(xx,requires_grad = True)
model.forward(xx)
xx_feat = hooks[lidx].feat
loss_xx_ref = torch.dist(ref_feat,xx_feat)/torch.norm(ref_feat.view(-1))
print(f'loss_x2_ref {loss_xx_ref}')
loss_xx_ref.backward()
xx_grad = xx.grad.clone()
print(xx_grad.norm())

Output:

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.torch/models/vgg16-397923af.pth
100%|██████████| 553433881/553433881 [00:05<00:00, 97564919.92it/s] 
tensor(543.5735, grad_fn=<DivBackward1>)
tensor(0.4398)
loss_x2_ref 23.292146682739258
tensor(0.0094)

If you make your example more minimal, it will be easier to see. Also, I’d recommend against using Variable (and against using torch versions where you have to).
Then, xx = x.clone() will give you an xx that is connected to x in the backward pass, i.e. losses calculated from xx will backpropagate into x, too, whereas losses calculated from x will only show up in x. (This is cumulative, i.e. gradients from all backward calls are added.)
On the other hand, xx = x.detach().clone().requires_grad_() will give you something that is completely separate.
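Here is a minimal sketch of the two behaviours (plain tensors with requires_grad instead of Variable; the shapes are arbitrary, this is separate from your feature-visualization code):

import torch

# Leaf tensor with gradient tracking -- no Variable needed on recent PyTorch.
x = torch.randn(3, requires_grad=True)

# .clone() stays connected: a loss on the clone backwards into x.grad.
xx = x.clone()
(xx ** 2).sum().backward()
print(x.grad)                      # populated through the clone

# .detach().clone().requires_grad_() is a completely separate leaf.
x.grad = None                      # reset for the comparison
yy = x.detach().clone().requires_grad_()
(yy ** 2).sum().backward()
print(x.grad)                      # None -- nothing flows back into x
print(yy.grad)                     # the gradient lives on yy only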

Best regards

Thomas

Thanks!

  1. I thought it would be better to give executable code; I will post a minimal snippet next time.
  2. My problem seems to be solved after calling x.grad.data.zero_(); it seems the gradients were accumulating in the variable of interest (see the sketch after this list). Will do some more tests to verify.
  3. Thanks for the info on clone; it was short-circuiting the gradients. One question: why let the clone affect the original's gradients? If someone wanted something connected to x, they would have used x itself, not a clone of it. I was thinking of the clone as just a numeric copy of x, otherwise disconnected from it, like in Theano.
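For reference, a small sketch of the accumulation behaviour from point 2 (separate from the feature-visualization code above):

import torch

# Gradients accumulate across backward() calls until .grad is cleared.
x = torch.randn(3, requires_grad=True)
loss = (x ** 2).sum()

loss.backward(retain_graph=True)
print(x.grad)          # 2 * x

loss.backward(retain_graph=True)
print(x.grad)          # 4 * x -- the second backward added onto the first

x.grad.zero_()         # clear before the next backward
loss.backward()
print(x.grad)          # 2 * x again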

Ideally, you would come up with a minimal runnable example. If you look at the time economics of a forum like this, it helps all of us get better answers to our questions when we all do that, because "answer time" is a relatively scarce resource.

Usually you would call opt.zero_grad() or model.zero_grad().
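For context, the usual pattern looks something like this (model, opt, and the data here are just placeholders):

import torch

model = torch.nn.Linear(4, 1)                         # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.1)     # placeholder optimizer
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

for _ in range(3):
    opt.zero_grad()                                   # clear last step's gradients
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()                                   # fresh gradients on the parameters
    opt.step()                                        # update with those gradients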

Because people want that, and they use .detach() when they don't: .clone() means "different memory, but connected in autograd"; .detach() means "disconnected in autograd, but same memory".
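A small sketch of that distinction (standard tensor ops, not tied to your code):

import torch

x = torch.zeros(3, requires_grad=True)

d = x.detach()       # same memory as x, but cut out of the autograd graph
d[0] = 1.0           # writing through d also changes x's values
print(x)             # tensor([1., 0., 0.], requires_grad=True)

c = x.clone()        # new memory, but still connected to x in autograd
c.sum().backward()
print(x.grad)        # tensor([1., 1., 1.]) -- the loss on c backwards into x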

Best regards

Thomas