Gradient of gradient is different from numerical gradient of gradient

Hi, I’m trying to calculate the gradient of a gradient.
I can obtain it by using .backward(create_graph=True).
To verify it, I wrote the following code, which compares the numerical gradient against the analytical one.
But I found the two are quite different and I need to find out why.
It seems like I calculated something wrong and the one PyTorch computes is correct, but
could you let me know what is wrong?
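
For reference, here is a minimal sketch of the double-backward pattern I'm using, on a toy scalar function instead of my actual network (f(x) = x**3, so the second derivative should be 6x):

import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3

# first derivative, kept differentiable with create_graph=True
grad_x, = torch.autograd.grad(y, x, create_graph=True)
print(grad_x)   # 12.0 = 3 * x**2

# differentiate the gradient itself to get the second derivative
grad2_x, = torch.autograd.grad(grad_x, x)
print(grad2_x)  # 12.0 = 6 * x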

Point: in the output below, numerical_score_grad is correct, but numerical_target_feature_map_grad_sum_grad is different from what register_hook prints.
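
(For context on why the hook prints twice below: register_hook fires every time a gradient for that tensor is computed, so each backward pass that reaches the weight triggers it once. A toy sketch, independent of my model:)

import torch

w = torch.tensor(3.0, requires_grad=True)
w.register_hook(lambda g: print('hook saw grad:', g))

y = w * w
# first backward: hook fires with dy/dw = 2w = 6.0;
# create_graph=True keeps the gradient differentiable
y.backward(create_graph=True)

grad_w = w.grad   # 6.0, and it carries a grad_fn
w.grad = None
# second backward, through the gradient itself:
# hook fires again with d(2w)/dw = 2.0
grad_w.backward()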

import os
import argparse
import cv2
import numpy as np
import torch
from torch.autograd import Function
from torchvision import models


def check_grad(grad):
    print('########################### register_hook ##############################')
    print('analytical grad :')
    if grad is not None:
        print('shape : {}'.format(grad.shape))
        print(grad[:5,0,0,0])
    else:
        print('{}'.format(grad))
    print('########################################################################')


if __name__ == '__main__':
    use_cuda = True

    # seed before creating the random input so the run is reproducible
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # the same random input (and label) is reused for all 6 iterations
    inputs = [torch.randn(1, 3, 224, 224).type(torch.float64)] * 6
    labels = [243] * 6

    delta = 1e-8
    model = models.resnet50(pretrained=True).double().requires_grad_(True)
    model = model.cuda() if use_cuda else model
    
    # target_layer_names is the layer that I'm gonna get a feature map from
    # test_layer_names is the layer that I'm gonna calculate the grad of grad
    target_layer_names='layer3'
    test_layer_names='layer4'
    
    # model.layer4[0].conv1.weight is the weight I'm going to calculate the grad of grad
    handle = model.layer4[0].conv1.weight.register_hook(check_grad)
    
    # save original weight
    ori_weight = model.layer4[0].conv1.weight.clone().detach()
    ori_weight = ori_weight.cuda() if use_cuda else ori_weight
    
    # These are the things I'm going to monitor
    input_list = []
    weight_list = []
    target_grad_list = []
    score_list = []
    target_feature_map_grad_sum_list = []

    
    
    for num, (input, label) in enumerate(zip(inputs, labels)):
        print('loop start')
        
        # skip perturbing a weight for the first iteration 
        if num != 0:
            # restore the saved weight to the layer
            model.layer4[0].conv1.weight.data = ori_weight.clone().detach()
            # add delta to a single weight to calculate numerical gradient
            model.layer4[0].conv1.weight.data[num-1,0,0,0] += delta
            
        # save weight
        weight_list.append(model.layer4[0].conv1.weight.data.cpu().detach().numpy()[:,:,0,0])
        
        input = input.clone().detach().requires_grad_(True)
        input = input.cuda() if use_cuda else input
        label = torch.tensor([label]).cuda() if use_cuda else torch.tensor([label])
        
        # Perform forward propagation manually
        output = input
        # save input
        input_list.append(input.cpu().detach().numpy())
        target_feature = None
        test_feature = None
        for name, module in model._modules.items():
            if "avgpool" in name.lower():
                output = module(output)
                output = output.view(output.size(0),-1)
            else:
                output = module(output)
                if name == target_layer_names:
                    target_feature = output
        
        # Use .retain_grad() to get gradient of the feature map later
        target_feature.retain_grad()

        one_hot = output[0, label] # output shape : (1, 1000)
        # save the score
        score_list.append(one_hot.cpu().detach().numpy()[0])

        ################################################################################
        model.zero_grad()
        target_feature.grad = None
        
        # check the initial gradient
        if model.layer4[0].conv1.weight.grad is not None:
            print('model.layer4[0].conv1.weight.grad :', model.layer4[0].conv1.weight.grad[:5,0,0,0])
        else:
            print('model.layer4[0].conv1.weight.grad :', model.layer4[0].conv1.weight.grad)
#         target_grad = torch.autograd.grad(one_hot, target_feature, retain_graph=True, create_graph=True)
        
        # calculate gradient
        one_hot.backward(retain_graph=True, create_graph=True)
        
        # save gradient map of the feature map
        target_grad_list.append(target_feature.grad.cpu().detach().numpy()) # (1, 1024, 14, 14)
        
        model.zero_grad()
        # sum all gradients of the gradient map
        target_grad_sum = target_feature.grad.sum()
        # save the summed gradient map
        target_feature_map_grad_sum_list.append(target_grad_sum.cpu().detach().numpy())
        # calculate the gradient of gradient
        target_grad_sum.backward()
        ################################################################################
        
        if num == 0:
            handle.remove()
        
        print()
    
    print('######################## loop ended now check out results ########################')
    
    # See the result
    # input_list[0].shape : (1, 3, 224, 224)
    input_diff = np.mean(np.concatenate(input_list[1:], axis=0) - input_list[0], axis=(1, 2, 3))
    print('input_diff :\n', input_diff)
    print()
    
    # weight_list[0].shape : (512, 1024)
    weight_diff = []
    for weight_ele in weight_list[1:]:
        weight_diff.append((weight_ele - weight_list[0])[:10,0])
    print('weight_diff :\n', weight_diff)
    print()
        
    # target_grad_list[0] : (1, 1024, 14, 14)
    target_feature_map_grad_diff_sum = []
    for tg_ele in target_grad_list[1:]:
        target_feature_map_grad_diff_sum.append((tg_ele - target_grad_list[0]).sum())
    print('target_feature_map_grad_diff_sum :\n', target_feature_map_grad_diff_sum)
    print()
    
    # score_array.shape : (6,); diffing against the unperturbed run gives 5 values
    score_array = np.array(score_list)
    numerical_score_grad = (score_array[1:] - score_array[0]) / delta
    print('numerical_score_grad :\n', numerical_score_grad)
    print()
    
    # target_feature_map_grad_sum_array.shape : (6,)
    target_feature_map_grad_sum_array = np.array(target_feature_map_grad_sum_list)
    numerical_target_feature_map_grad_sum_grad = (target_feature_map_grad_sum_array[1:] - target_feature_map_grad_sum_array[0]) / delta
    print('numerical_target_feature_map_grad_sum_grad :\n', numerical_target_feature_map_grad_sum_grad)
    print()

output

loop start
model.layer4[0].conv1.weight.grad : None
########################### register_hook ##############################
analytical grad :
shape : torch.Size([512, 1024, 1, 1])
tensor([0.0055, 0.0113, 0.0097, 0.0040, 0.0040], device='cuda:0',
dtype=torch.float64, grad_fn=<...>)
########################################################################
########################### register_hook ##############################
analytical grad :
shape : torch.Size([512, 1024, 1, 1])
tensor([-5.7899e-17, 1.3624e-16, -1.5690e-16, -1.4070e-16, 2.3213e-16],
device='cuda:0', dtype=torch.float64)
########################################################################

loop start
model.layer4[0].conv1.weight.grad : tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float64)

loop start
model.layer4[0].conv1.weight.grad : tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float64)

loop start
model.layer4[0].conv1.weight.grad : tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float64)

loop start
model.layer4[0].conv1.weight.grad : tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float64)

loop start
model.layer4[0].conv1.weight.grad : tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float64)

######################## loop ended now check out results ########################
input_diff :
[0. 0. 0. 0. 0.]

weight_diff :
[array([1.e-08, 0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00,
0.e+00, 0.e+00]), array([0.e+00, 1.e-08, 0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00,
0.e+00, 0.e+00]), array([0.e+00, 0.e+00, 1.e-08, 0.e+00, 0.e+00, 0.e+00, 0.e+00, 0.e+00,
0.e+00, 0.e+00]), array([0.e+00, 0.e+00, 0.e+00, 1.e-08, 0.e+00, 0.e+00, 0.e+00, 0.e+00,
0.e+00, 0.e+00]), array([0.e+00, 0.e+00, 0.e+00, 0.e+00, 1.e-08, 0.e+00, 0.e+00, 0.e+00,
0.e+00, 0.e+00])]

target_feature_map_grad_diff_sum :
[2.2546648568304067e-15, 3.220287531985217e-15, -6.699349743875711e-16, -6.343048378637844e-16, -3.324264843049737e-16]

numerical_score_grad :
[0.00550495 0.01130982 0.0096729 0.00400663 0.00398234]

numerical_target_feature_map_grad_sum_grad :
[2.62290190e-07 3.63598041e-07 6.38378239e-08 1.66533454e-08
3.88578059e-08]

Hi,

Do I understand correctly that the difference between the Hessian value you computed by hand and the one computed with PyTorch is on the order of 1e-2?

Your code is fairly long (and uses local files) so it is not super easy to reproduce.

Hi,
I’m sorry that I didn’t explain it well.
I also changed my code so that it can be run without any additional files.
What I’m doing in the code is:

  1. I feed a random image into resnet50.
    (the top-level layers of resnet50 are as follows)

conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc

  2. I take the output of layer3 and compute its gradient map (the gradient of a single score with respect to the output of layer3).
  3. I sum all the elements of that gradient map. Let’s call this grad_sum_loss.
  4. I compute the gradient of grad_sum_loss with respect to “layer4[0].conv1.weight”. Note that it’s layer4, not layer3. (A condensed sketch of steps 2–4 follows this list.)
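
Here is that condensed sketch, using torch.autograd.grad instead of .backward() (it assumes `model`, `input`, and `label` are set up as in my full code above):

target_feature = None
output = input
for name, module in model._modules.items():
    output = module(output)
    if name == 'avgpool':
        output = output.view(output.size(0), -1)  # flatten before fc
    if name == 'layer3':
        target_feature = output

score = output[0, label]
# step 2: gradient of the score w.r.t. the layer3 feature map,
# kept differentiable with create_graph=True
feat_grad, = torch.autograd.grad(score, target_feature, create_graph=True)
# step 3: sum the gradient map
grad_sum_loss = feat_grad.sum()
# step 4: gradient of grad_sum_loss w.r.t. the layer4 weight
weight_grad2, = torch.autograd.grad(grad_sum_loss, model.layer4[0].conv1.weight)
print(weight_grad2[:5, 0, 0, 0])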

So when you look at the output, there are two register_hook printouts.
Both come from the hook registered on “layer4[0].conv1.weight”, but each comes from a different .backward() call.
The hook fires the first time when the gradient of a single score with respect to the output of layer3 is calculated (step 2 above).
It fires the second time when the gradient of grad_sum_loss with respect to “layer4[0].conv1.weight” is calculated (step 4).

I used a for loop to calculate the numerical gradient for each weight of “layer4[0].conv1.weight”.
In the output you can see numerical_target_feature_map_grad_sum_grad, which is the numerical gradient of grad_sum_loss with respect to “layer4[0].conv1.weight”.
The problem is that this differs from the second register_hook result.
(I also calculated the numerical gradient of a single score with respect to “layer4[0].conv1.weight”, and that one matches the first register_hook result.)
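
Concretely, for each perturbed index i the loop computes the one-sided difference quotient
numerical_grad[i] = (grad_sum_loss(w + delta * e_i) - grad_sum_loss(w)) / delta,
where e_i perturbs only weight[i, 0, 0, 0] and delta = 1e-8.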

So you mean that you are comparing the second hook, which prints tensor([-5.7899e-17, 1.3624e-16, -1.5690e-16, -1.4070e-16, 2.3213e-16], device='cuda:0', dtype=torch.float64), against [2.62290190e-07 3.63598041e-07 6.38378239e-08 1.66533454e-08 3.88578059e-08]?

The differences look small enough that this is just numerical error, no?
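
Back-of-the-envelope, using only the numbers from your own printout:

# the numerator of the one-sided difference is already at float64 noise level:
# your target_feature_map_grad_diff_sum shows the grad-map sum changes by ~2e-15
# between runs, while the true change (second derivative ~1e-16 times
# delta = 1e-8) would be ~1e-24, far below double-precision resolution
numerator = 2.25e-15      # first entry of target_feature_map_grad_diff_sum
delta = 1e-8
print(numerator / delta)  # ~2.2e-07, the same order as your numerical estimates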