No grad_fn despite never calling .detach()

For my ConvLSTM autoencoder:

import torch.nn as nn
import torch
from ..utils.utils import zeros, mean_cube, last_frame, ENS
from .Conv_LSTM import Conv_LSTM_Cell

class AutoencLSTM(nn.Module):
    """Encoder-Decoder architecture based on ConvLSTM"""
    def __init__(self, input_dim, output_dim, hidden_dims, big_mem, kernel_size, memory_kernel_size, dilation_rate,
                    img_width, img_height, layer_norm_flag=False, baseline="last_frame", num_layers=1, peephole=True):
        super(AutoencLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers

        self.input_dim = input_dim                                                     # n of channels in input pics
        self.h_channels = [[], []]
        self.h_channels[0] = self._extend_for_multilayer(hidden_dims, num_layers)      # n of hidden channels for encoder cells
        self.h_channels[1] = self._extend_for_multilayer(hidden_dims, num_layers - 1)  # n of hidden channels for decoder cells
        self.h_channels[1].append(output_dim)                                          # n of channels in output pics
        self.big_mem = big_mem                                                         # True: c has as many channels as h, False: c has 1 channel
        self.num_layers = num_layers                                                   # n of stacked ConvLSTM layers
        self.kernel_size = kernel_size
        self.memory_kernel_size = memory_kernel_size                                   # kernel size (no magic here)
        self.dilation_rate = dilation_rate
        self.layer_norm_flag = layer_norm_flag
        self.img_width = img_width
        self.img_height = img_height
        self.baseline = baseline
        self.peephole = peephole

        cur_input_dim = [self.input_dim if i == 0 else self.h_channels[0][i - 1] for i in range(self.num_layers)]
        self.ENC = nn.ModuleList([Conv_LSTM_Cell(cur_input_dim[i], self.h_channels[0][i], big_mem, kernel_size, memory_kernel_size, dilation_rate, 
                                                 layer_norm_flag, img_width, img_height, peephole) for i in range(num_layers)])
        self.DEC = nn.ModuleList([Conv_LSTM_Cell(self.h_channels[0][i], self.h_channels[1][i], big_mem, kernel_size, memory_kernel_size, dilation_rate, 
                                                 layer_norm_flag, img_width, img_height, peephole) for i in range(num_layers)])

    def forward(self, input_tensor, non_pred_feat=None, prediction_count=1):
        # self.baseline names one of the helpers imported from utils (zeros, mean_cube, last_frame)
        baseline = eval(self.baseline + "(input_tensor[:, 0:5, :, :, :], 4)")
        b, _, width, height, T = input_tensor.size()
        hs = [[], []]
        cs = [[], []]

        # For encoder and decoder
        for j, part in enumerate([self.ENC, self.DEC]):
            for i in range(self.num_layers):
                h, c = part[i].init_hidden(b, (height, width))
                hs[j].append(h)
                cs[j].append(c)

        pred_deltas = torch.zeros((b, self.h_channels[1][-1], height, width, prediction_count), device = self._get_device())
        preds = torch.zeros((b, self.h_channels[1][-1], height, width, prediction_count), device = self._get_device())
        baselines = torch.zeros((b, self.h_channels[1][-1], height, width, prediction_count), device = self._get_device())

        # iterate over the past
        for t in range(T):
            hs[0][0], cs[0][0] = self.ENC[0](input_tensor=input_tensor[..., t], cur_state=[hs[0][0], cs[0][0]])
            for i in range(1, self.num_layers):
                # encode
                hs[0][i], cs[0][i] = self.ENC[i](input_tensor=hs[0][i - 1], cur_state=[hs[0][i], cs[0][i]])
                # decode
                hs[1][i - 1], cs[1][i - 1] = self.DEC[i - 1](input_tensor=hs[0][i], cur_state=[hs[1][i - 1], cs[1][i - 1]])
        
        baselines[..., 0] = baseline
        pred_deltas[..., 0] = hs[1][-1]
        preds[..., 0] = pred_deltas[..., 0] + baselines[..., 0]

        # add a mask to prediction
        if prediction_count > 1:
            non_pred_feat = torch.cat((torch.zeros((non_pred_feat.shape[0],
                                                    1,
                                                    non_pred_feat.shape[2],
                                                    non_pred_feat.shape[3],
                                                    non_pred_feat.shape[4]), device=non_pred_feat.device), non_pred_feat), dim = 1)

            # iterate over the future
            for t in range(1, prediction_count):
                # glue together with non_pred_data
                prev = torch.cat((preds[..., t - 1], non_pred_feat[..., t - 1]), axis=1)

                hs[0][0], cs[0][0] = self.ENC[0](input_tensor=prev, cur_state=[hs[0][0], cs[0][0]])
                for i in range(1, self.num_layers):
                    # encode
                    hs[0][i], cs[0][i] = self.ENC[i](input_tensor=hs[0][i - 1], cur_state=[hs[0][i], cs[0][i]])
                    # decode
                    hs[1][i - 1], cs[1][i - 1] = self.DEC[i - 1](input_tensor=hs[0][i], cur_state=[hs[1][i - 1], cs[1][i - 1]]) 

                pred_deltas[..., t] = hs[1][-1]

                if self.baseline == "mean_cube":
                    baselines[..., t] = (preds[..., t-1] + (baselines[..., t-1] * (T + t)))/(T + t + 1)
                elif self.baseline == "zeros":
                    pass
                else:
                    baselines[..., t] = preds[..., t-1]

                preds[..., t] = pred_deltas[..., t] + baselines[..., t]

        return preds, pred_deltas, baselines
    
    def _get_device(self):
        return next(self.parameters()).device

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (isinstance(kernel_size, tuple) or
                isinstance(kernel_size, int) or
                # lists are currently not supported for Peephole_Conv_LSTM
                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be an int, a tuple or a list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, rep):
        if not isinstance(param, list):
            if rep > 0:
                param = [param] * rep
            else:
                return []
        return param

I get the error

element 0 of tensors does not require grad and does not have a grad_fn

despite never calling .detach() on the output or anywhere else in the code. It does not make sense to me that gradients cannot flow through preds, which holds the predictions.
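
For completeness, this is roughly how the error is triggered (a simplified sketch, not my real training loop; the shapes and hyperparameters are made up, and it still needs Conv_LSTM_Cell and the utils helpers from above):

model = AutoencLSTM(input_dim=8, output_dim=4, hidden_dims=16, big_mem=True,
                    kernel_size=3, memory_kernel_size=3, dilation_rate=1,
                    img_width=32, img_height=32, num_layers=2)

# random stand-in data with layout (b, channels, width, height, time), as in forward()
input_tensor = torch.rand(2, 8, 32, 32, 10)
target = torch.rand(2, 4, 32, 32, 1)

preds, pred_deltas, baselines = model(input_tensor)
loss = nn.MSELoss()(preds, target)
loss.backward()  # RuntimeError: element 0 of tensors does not require grad ...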

Could you post a code snippet that runs your model on random input data, so that we can try to reproduce it, please?


Sure! Here is a small dummy version of the code. It yields the same error for me.

from torch import nn
import torch


class Simple_Model(nn.Module):
    def __init__(self):
        super(Simple_Model, self).__init__()
        self.layer = torch.nn.Linear(1, 1)

    def forward(self, input_tensor, cur_state):
        return self.layer(input_tensor), self.layer(input_tensor)

class Foo(nn.Module):
    def __init__(self):
        super(Foo, self).__init__()

    def forward(self):
        input_tensor = torch.zeros(2, 1, 1)
        pred_deltas = torch.zeros((2, 1, 1))
        preds = torch.zeros((2, 1, 1))
        baselines = torch.zeros((2, 1, 1)) 

        model = Simple_Model()

        hs = [[], []]
        cs = [[], []]
        
        for i in range(3):
            hs[0].append(torch.zeros(2, 1))
            hs[1].append(torch.zeros(2, 1))
            cs[0].append(torch.zeros(2, 1))
            cs[1].append(torch.zeros(2, 1))

        # iterate over the past
        for t in range(1):
            hs[0][0], cs[0][0] = model(input_tensor=input_tensor[..., t], cur_state=[hs[0][0], cs[0][0]])
            for i in range(1, 2):
                # encode
                hs[0][i], cs[0][i] = model(input_tensor=hs[0][-1], cur_state=[hs[0][i], cs[0][i]])
                # decode
                hs[1][i - 1], cs[1][i - 1] = model(input_tensor=hs[0][i], cur_state=[hs[1][i - 1], cs[1][i - 1]])
        
        baselines[..., 0] = input_tensor[..., 0]
        pred_deltas[..., 0] = hs[1][-1]
        preds[..., 0] = pred_deltas[..., 0] + baselines[..., 0]

        return preds

def main():
    model = Foo()
    out = model.forward()
    out.backward()

if __name__ == "__main__":
    main()
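
Running it fails at out.backward() with the same message:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn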

Thanks for the update. I’m not sure what your code does exactly (e.g. since cur_state is never used), but the error is raised because you are not assigning new tensors to all entries of hs (in particular not to the last one).
I.e., before the loop hs will be initialized as:

[[tensor([[0.],
          [0.]]),
  tensor([[0.],
          [0.]]),
  tensor([[0.],
          [0.]])],
 [tensor([[0.],
          [0.]]),
  tensor([[0.],
          [0.]]),
  tensor([[0.],
          [0.]])]]

After the loop, some entries will be overwritten as:

[[tensor([[-0.0598],
          [-0.0598]], grad_fn=<AddmmBackward0>),
  tensor([[-0.0598],
          [-0.0598]], grad_fn=<AddmmBackward0>),
  tensor([[0.],
          [0.]])],
 [tensor([[-0.0975],
          [-0.0975]], grad_fn=<AddmmBackward0>),
  tensor([[0.],
          [0.]]),
  tensor([[0.],
          [0.]])]]

To calculate preds you are using:

        baselines[..., 0] = input_tensor[..., 0]
        pred_deltas[..., 0] = hs[1][-1]
        preds[..., 0] = pred_deltas[..., 0] + baselines[..., 0]

Note that input_tensor neither requires gradients nor has a grad_fn, so baselines is not differentiable here but a constant.
hs[1][-1] returns:

tensor([[0.],
        [0.]])

which still contains the initial zeros and no computed tensor. It is therefore not attached to the computation graph either, and preds.backward() will fail.
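
You can see the mechanism in isolation (a tiny illustration, not your actual code): writing a differentiable result into a zeros buffer attaches the buffer to the graph, while writing only constants leaves it detached:

import torch
from torch import nn

layer = nn.Linear(1, 1)

buf = torch.zeros(2, 1, 1)
buf[..., 0] = layer(torch.ones(2, 1))                # differentiable value written into the buffer
print(buf.grad_fn)                                   # <CopySlices ...> -> backward() can pass through it

buf2 = torch.zeros(2, 1, 1)
buf2[..., 0] = torch.ones(2, 1) + torch.ones(2, 1)   # constants only
print(buf2.grad_fn)                                  # None -> backward() through buf2 raises your error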


Thank you so much! There was a bug (as you say, hs[1][-1] is never overwritten). Thanks to your helpful note, I finally found it.
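
For anyone who stumbles over the same thing: the inner loop for i in range(1, self.num_layers) never touches the last decoder cell (and with num_layers == 1 it runs no decoder cell at all), so hs[1][-1] keeps its initial zeros. One way to make sure it actually gets computed is to run the decoder over all layers, roughly like this (a sketch, assuming the output of ENC[i] should feed DEC[i]):

for t in range(T):
    # encoder stack
    hs[0][0], cs[0][0] = self.ENC[0](input_tensor=input_tensor[..., t], cur_state=[hs[0][0], cs[0][0]])
    for i in range(1, self.num_layers):
        hs[0][i], cs[0][i] = self.ENC[i](input_tensor=hs[0][i - 1], cur_state=[hs[0][i], cs[0][i]])
    # decoder stack: update every decoder cell, so hs[1][-1] ends up with a grad_fn
    for i in range(self.num_layers):
        hs[1][i], cs[1][i] = self.DEC[i](input_tensor=hs[0][i], cur_state=[hs[1][i], cs[1][i]])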

PS: As I already wrote, the code snippet I provided is only a dummy, since the actual code would be too long to post, so it was not supposed to make full sense. In the original code we do, of course, use cur_state.

Cheers!