The result obtained after an operation from a tensor of a pyTorch model does not have grad_fn

The result obtained after an operation from a tensor obtained from pyTorch model does not have grad_fn and can not backward. I use .clone().detach() to the input and cast tensors of the compute_grad function shown below, So they should be disassociated from the previous model. However, the grad_fn of y is None, and the dy.sum().backward can not run.
Here is my code.

# All necessary imports at the beginning
import torch
import torchvision
import torch.nn as nn
from torch import Tensor
from torch.autograd import Variable
from torch.utils.data import DataLoader

# here is a custom grad_weight_computation function, the 'input' is the input of a specific layer in a pytorch model
# and the 'layer' is exactly the layer, the 'cast' is the ratio of grad_output to output, the grad_output is the
# partial derivative of loss wrt to output, the output is the output of the specific layer, in my case, the layer
# is always nn.Conv2d or nn.Linear
def compute_grad(input: Tensor, layer: nn.Linear, cast : Tensor):
    x = input.clone().detach()
    cast = cast.clone().detach()
    x.requires_grad = True
    
    # Create an identical linear layer without backward
    func = nn.Linear(in_features=layer.in_features, out_features=layer.out_features, bias=True)
    # Import parameter dictionary
    func.load_state_dict(layer.state_dict())
    
    y = func(x)
    # The anomaly is: the grad_fn of y is None, but the attribute of is_leaf is True
    print(y.grad_fn)
    # out:None
    print(y.is_leaf)
    # out:True
    dy = y * cast
    dy.sum().backward()
    dw = func.weight.grad
    print(dw)
    return dw

def register_for_hook(model):
    for _, i in model.named_modules():
        if isinstance(i, nn.Linear):  
            i.register_forward_hook(forward_hook)
            i.register_backward_hook(backward_hook) 

def forward_hook(module, input, output):
    # the forward_hook is used to record the inputs, outputs and layers of all linear layers
    # in a pytorch model during running.
    # In order to reduce the amount of code, there is no need to list in detail
    return 
    
def backward_hook(module, grad_in, grad_out):
    # the backward_hook is used to 
    # 1. Obtain the input, output and layer that we recorded in the forward_hook
    # 2. Since the input and output were encrypted, so we decrypt them
    # 3. Pass the input, layer and gout/output to our compute_grad func to compute real grad
    # new_dw = compute_grad(input=input, layer=module, cast=grad_out/output)
    # return new_dw to replace the original dw
    return 

Epoch=3
Batch_Size=50
LR=0.0001
num_classes=10
net = torchvision.models.vgg16(pretrained=True)
# register hook for the net
register_for_hook(net)

# datasets
trainData=torchvision.datasets.MNIST(
    root="/home/lpx/codes/hook/data",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True)
 
train_loader=DataLoader(dataset=trainData,batch_size=Batch_Size,shuffle=True)
test_data=torchvision.datasets.MNIST(root="home/lpx/codes/hook/data",train=False,download=True)

def Train(model):
    # Run the regular training process, the loss will go throught the backward process
    # and trigger the backward hook.
    # loss.backward()
    # In order to reduce the amount of code, there is no need to list in detail
    return
Train(net)

The error message says

'''
Traceback (most recent call last):
  File "vgg.py", line 159, in <module>
    Train(net)
  File "vgg.py", line 158, in Train
    loss.backward()
  File "/home/lpx/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/lpx/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/autograd/__init__.py", line 97, in backward
    Variable._execution_engine.run_backward(
  File "vgg.py", line 132, in backward_hook
    new_dw = compute_grad(input, layer, cast)
  File "vgg.py", line 114, in self_grad
    dy.sum().backward()
  File "/home/lpx/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/lpx/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/autograd/__init__.py", line 97, in backward
    Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
'''

And by contrast, if the input, layer and output are not obtained from a running pytorch model, instead they are created by users, the the same codes can run without error. You can see below:

import torch
from torch import nn, Tensor
from torch.autograd import Variable

def compute_grad(input: Tensor, layer, cast):
    x = input.clone().detach()
    cast = cast.clone().detach()
    x.requires_grad = True
    func = nn.Linear(in_features=layer.in_features, out_features=layer.out_features, bias=True)
    func.load_state_dict(layer.state_dict())
    
    y = func(x)
    print(y.grad_fn)
    # out: <AddmmBackward object at 0x7f5028a50fa0>
    y = y.requires_grad_()
    dy = y * cast
    dy = dy.requires_grad_()

    dy.sum().backward()
    dw = func.weight.grad.data
    return dw

input = torch.rand([50, 512], requires_grad=True)
m = nn.Linear(in_features=512, out_features=10, bias=True)
output = m(input)
grad_out = torch.randn_like(output)
# Simulate the real model running process through backward
output.sum().backward()

h = grad_out / output
new_dw = compute_grad(input, m, h)
print(new_dw)
# out: Tensor

The background of this problem is that we want to be able to encrypt the intermediate value when running a pytorch model. I think it’s very valuable. So if you can help me with this suspected bug I would be very grateful.
I have tried many versions of pytorch but got the same problem.

In your code you are not running anything as the Train method is only containing a comment.
Using:

input = torch.randn(1, 1)
layer = nn.Linear(1, 1)
cast = torch.randn(1, 1)

compute_grad(input, layer, cast)
# <AddmmBackward0 object at 0x7f2424679af0>
# False
# tensor([[-0.3559]])

works for me and shows a valid .grad_fn.

Thanks for your quick reply!
The reason why the training method only contains comments is because it is just a common model training process and has nothing to do with the essence of this problem. In order to reduce the amount of code, I did not list the training code, but it actually ran in my code.

If you look closely to my code after And by contrast…, you will find it similar to your code. The real problem is that if input, layer and cast are obtained from a running model but not created by us using torch.rand().

Hope you understand! And I sincerely appreciate it if you can have a look again.

I cannot try to speculate how you might be executing the code to recreate the issue, so please feel free to post a minimal, executable code snippet reproducing the error.

Thanks for your quick reply! Here is my executable code which produces the unexpected error:

import torch
import torchvision
import torch.nn as nn
from typing import Union
from torch import Tensor
from torch.autograd import Variable, variable
from torch.utils.data import DataLoader
VGG_types = {
"VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"VGG13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"VGG16": [64,64,"M",128,128,"M",256,256,256,"M",512,512,512,512,512,512],
"VGG19": [64,64,"M",128,128,"M",256,256,256,256,"M",512,512,512,512,
          "M",512,512,512,512,"M",],}

# this class aims to record down all the inputs, outputs and layers of a running model
class trace_layers():
    def __init__(self) -> None:
        self.inputs = []
        self.outputs = []
        self.layers = []
        self.n = 0
        pass
    
    def update(self, model, input, output):
        self.n += 1
        self.layers.append(model)
        self.inputs.append(input)
        self.outputs.append(output)

        
    def get_cur(self):
        layer = self.layers[self.n - 1]
        x = self.inputs[self.n - 1]
        y = self.outputs[self.n - 1]
        self.n -= 1
        return layer, x, y

layer_list = trace_layers()

# build model
VGGType = "VGG16"
class VGGnet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(VGGnet, self).__init__()
        self.in_channels = in_channels
        self.conv_layers = self.create_conv_layers(VGG_types[VGGType])
        self.fcs = nn.Sequential(
        nn.Linear(512 * 3 * 3, 512),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(512, num_classes),
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fcs(x)
        return x

    def create_conv_layers(self, architecture):
        layers = []
        in_channels = self.in_channels

        for x in architecture:
            if type(x) == int:
                out_channels = x
                layers += [
                nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                ),
                nn.BatchNorm2d(x),
                nn.ReLU(),
                ]
                in_channels = x
            elif x == "M":
                layers += [nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))]

        return nn.Sequential(*layers)

# determine whether it is a linear layer
def isLinear(feature):
    if isinstance(feature, nn.Linear) or isinstance(feature, nn.Conv2d):
        return True
    else:
        return False
    
def register_for_hook(model):
    for _, i in model.named_modules():
        if not isinstance(i, nn.Sequential):
            if isLinear(i):  
                i.register_forward_hook(forward_hook)
                i.register_backward_hook(backward_hook) 

def compute_grad(input: Tensor, layer, cast : Tensor):
    x = input.clone().detach()
    cast = cast.clone().detach()
    x.requires_grad = True
    func = nn.Linear(in_features=layer.in_features, out_features=layer.out_features, bias=True)
    func.load_state_dict(layer.state_dict())
    
    y = func(x)
    print('grad_fn of y:{}'.format(y.grad_fn))
    y = y.requires_grad_()
    dy = y * cast
    dy = dy.requires_grad_()

    dy.sum().backward()
    dw = func.weight.grad.data
    return dw

def forward_hook(module, input, output):
    layer_list.update(module, input, output)
    return None
    
def backward_hook(module, gin, gout):
    layer, input, output = layer_list.get_cur()
    input = input[0]
    dy = gout[0]
    dw: Tensor
    if isinstance(module, nn.Linear):
        db, dw, dx = gin
    elif isinstance(module, nn.Conv2d):
        dw, dx, db = gin
    cast = dy / output
    # call my custom compute_grad function
    self_dw = compute_grad(input, layer, cast)
    # the ultimate purpose is to find out whether self_dw and dw are identically same
    print(self_dw==dw)
    return None

# initializing
device = 'cpu' if torch.cuda.is_available() else 'cpu'
print('device =', device)
Epoch=3
Batch_Size=50
LR=0.0001
num_classes=10
net=VGGnet(in_channels=1, num_classes=num_classes).to(device)
register_for_hook(net)
trainData=torchvision.datasets.MNIST(
    root="/home/lpx/codes/hook/data",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True)
 
train_loader=DataLoader(dataset=trainData,batch_size=Batch_Size,shuffle=True)
test_data=torchvision.datasets.MNIST(root="home/lpx/codes/hook/data",train=False,download=True)

def Train(model):
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    for epoch in range(Epoch):
        for step,(b_x,b_y)in enumerate(train_loader):
            b_x = b_x.to(device)
            b_y = b_y.to(device)
            output = model(b_x)
            loss=loss_func(output, b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy())
            layer_list.inputs.clear()
            layer_list.outputs.clear()
            layer_list.layers.clear()
            layer_list.n = 0
            
    print('res finish training')

Train(net)

Thanks again for your time.

Thanks for the updated code. Autgrad is disabled in backward hooks, so you might want to enable it via with torch.enable_grad():. Once this is done you will run into shape mismatch errors:

    print(self_dw==dw)

RuntimeError: The size of tensor a (10) must match the size of tensor b (50) at non-singleton dimension 0

and if you are removing this print into the wrong usage of an unknown attribute:

AttributeError: 'Conv2d' object has no attribute 'in_features'

Thanks for your perfect answer! I am already aware of the reason. :grinning:

Hi ptrblck,

Could I ask one more question? The previous cast was: dy / output , where dy is d(loss)/dy and y is the output of the layer. But I found that the self_dw calculated by the compute_grad function is not the same as the dw I got from backward_hook. Is there any way for me to build some connections and calculate the correct derivative of loss to weight through the automatic derivation mechanism?

Thanks very much!

I don’t fully understand your approach since you are trying to unpack the grad_input into db, dw, dx which sounds wrong.
Here is a very small code example showing how to calculate dw = grad_out * input:

def compute_grad(input, layer, cast):
    x = input.clone().detach()
    cast = cast.clone().detach()
    func = nn.Linear(in_features=layer.in_features, out_features=layer.out_features, bias=False)
    func.load_state_dict(layer.state_dict())
    
    y = func(x)
    print('grad_fn of y:{}'.format(y.grad_fn))
    dy = y

    dy.sum().backward()
    dw = func.weight.grad.clone()
    return dw

    
def backward_hook(module, gin, gout):
    with torch.enable_grad():
        dy = gout[0]
        if isinstance(module, nn.Linear):
            dw = torch.matmul(dy.T, x)
        cast = out
        # call my custom compute_grad function
        self_dw = compute_grad(x, module, cast)
        # the ultimate purpose is to find out whether self_dw and dw are identically same
        print(self_dw==dw)
        print(self_dw - dw)
        return None


lin = nn.Linear(10, 10, bias=False)
lin.register_full_backward_hook(backward_hook)

x = torch.randn(1, 10)
out = lin(x)
out.sum().backward()

Sorry to bother again,
I just found that I should unpack the grad_input into db, dx, dw if it’s nn.Linear or dx, dw, db if it’s nn.Conv2d . Is that right?

By the way, In this code example, the lin is the last layer, so the derivation of the result to y is full-ones, but what if there is another layer after the lin module? Is it still right?

I try to run it in my model, and I found the dw obtained by torch.matmul(dy.T, x) and the dw obtained by function compute_grad are not the same.

No, I don’t think this is right as previously mentioned.
The grad_input tensor contains the gradient for the input to the method as seen here:

class MySum(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, x):
        return x.sum(1)
    
module = MySum()
x = torch.randn(1, 10, requires_grad=True)

module.register_full_backward_hook(lambda m, grad_input, grad_output: print(grad_input, grad_output))

out = module(x)
out.backward()
# (tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]),) (tensor([1.]),)
print(x.grad)
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

It does not contain the gradient w.r.t. the weight or bias.

Thanks for your time!

However, when I ran my code, I found that the grad_input really contains the gradient wrt to the weight and bias, please have a look at this code sample, from line 134-141 in backward_hook, you can find the grad_input actually prints out the three gradients mentioned above:

import torch
import torchvision
import torch.nn as nn
from typing import Union
from torch import Tensor
from torch.autograd import Variable, variable
from torch.utils.data import DataLoader
VGG_types = {
"VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"VGG13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
"VGG16": [64,64,"M",128,128,"M",256,256,256,"M",512,512,512,512,512,512],
"VGG19": [64,64,"M",128,128,"M",256,256,256,256,"M",512,512,512,512,
          "M",512,512,512,512,"M",],}

# this class aims to record down all the inputs, outputs and layers of a running model
class trace_layers():
    def __init__(self) -> None:
        self.inputs = []
        self.outputs = []
        self.layers = []
        self.n = 0
        pass
    
    def update(self, model, input, output):
        self.n += 1
        self.layers.append(model)
        self.inputs.append(input)
        self.outputs.append(output)

        
    def get_cur(self):
        layer = self.layers[self.n - 1]
        x = self.inputs[self.n - 1]
        y = self.outputs[self.n - 1]
        self.n -= 1
        return layer, x, y

layer_list = trace_layers()

# build model
VGGType = "VGG16"
class VGGnet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(VGGnet, self).__init__()
        self.in_channels = in_channels
        self.conv_layers = self.create_conv_layers(VGG_types[VGGType])
        self.fcs = nn.Sequential(
        nn.Linear(512 * 3 * 3, 512),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(512, num_classes),
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fcs(x)
        return x

    def create_conv_layers(self, architecture):
        layers = []
        in_channels = self.in_channels

        for x in architecture:
            if type(x) == int:
                out_channels = x
                layers += [
                nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                ),
                nn.BatchNorm2d(x),
                nn.ReLU(),
                ]
                in_channels = x
            elif x == "M":
                layers += [nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))]

        return nn.Sequential(*layers)

# determine whether it is a linear layer
def isLinear(feature):
    if isinstance(feature, nn.Linear) or isinstance(feature, nn.Conv2d):
        return True
    else:
        return False
    
def register_for_hook(model):
    for _, i in model.named_modules():
        if not isinstance(i, nn.Sequential):
            if isLinear(i):  
                i.register_forward_hook(forward_hook)
                i.register_backward_hook(backward_hook) 

def compute_grad(input: Tensor, layer, cast : Tensor):
    x = input.clone().detach()
    cast = cast.clone().detach()
    if isinstance(layer, nn.Linear):
        func = nn.Linear(in_features=layer.in_features, out_features=layer.out_features, bias=True)
    else:
        layer: nn.Conv2d
        func = nn.Conv2d(in_channels=layer.in_channels, out_channels=layer.out_channels, kernel_size=layer.kernel_size, stride=layer.stride\
            ,padding=layer.padding)
    func.load_state_dict(layer.state_dict())
    
    y = func(x)
    dy = y * cast
    dy.sum().backward()
    dw = func.weight.grad
    if dw is not None:
        dw = dw.clone()
        return dw

def forward_hook(module, input, output):
    layer_list.update(module, input, output)
    return None
    
def backward_hook(module, gin, gout):
    layer, input, output = layer_list.get_cur()
    input = input[0]
    dy = gout[0]
    if isinstance(module, nn.Linear):
        db, dx, dw = gin
    elif isinstance(module, nn.Conv2d):
        dx, dw, db = gin
    cast = dy / output
    # call my custom compute_grad function
    print(len(gin))
    # 3
    print(input.shape, module.bias.shape, module.weight.shape)
    # torch.Size([50, 512]) torch.Size([10]) torch.Size([10, 512])
    print(dx.shape, db.shape, dw.shape)
    # torch.Size([50, 512]) torch.Size([10]) torch.Size([512, 10])
    # the dw.shape is the transpose of module.weight.shape, the pytorch will transpose the
    # shape of weight during actual computation
    with torch.enable_grad():
        self_dw = compute_grad(input, module, cast)
    assert self_dw == dw

# initializing
device = 'cpu' if torch.cuda.is_available() else 'cpu'
print('device =', device)
Epoch=3
Batch_Size=50
LR=0.0001
num_classes=10
net=VGGnet(in_channels=1, num_classes=num_classes).to(device)
register_for_hook(net)
trainData=torchvision.datasets.MNIST(
    root="/home/lpx/codes/hook/data",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True)
 
train_loader=DataLoader(dataset=trainData,batch_size=Batch_Size,shuffle=True)
test_data=torchvision.datasets.MNIST(root="home/lpx/codes/hook/data",train=False,download=True)

def Train(model):
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    for epoch in range(Epoch):
        for step,(b_x,b_y)in enumerate(train_loader):
            b_x = b_x.to(device)
            b_y = b_y.to(device)
            output = model(b_x)
            loss=loss_func(output, b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy())
            layer_list.inputs.clear()
            layer_list.outputs.clear()
            layer_list.layers.clear()
            layer_list.n = 0
            
    print('res finish training')

Train(net)

I don’t know what the reason is, is it because your code sample doesn’t have an optimizer? Or is it because the difference between register_full_backward_hook and register_backward_hook?

I believe I have found the solution to my problem, thank you!