Getting RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Hi there!

I am trying to run a simple CNN2LSTM model and am running into this error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn.

The strange part is that the current model is a simpler version of my previous model, which worked absolutely fine.
To solve this error, I tried setting “requires_grad=True”, for which I had to cast my target tensors to float(), but that throws another error:

RuntimeError: Expected tensor for argument #1 ‘indices’ to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding).
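
For reference, nn.Embedding on its own seems to reproduce that second error when it receives float indices, so I assume the float() cast is the cause:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
idx = torch.tensor([1, 2, 3])
print(emb(idx).shape)    # torch.Size([3, 4]) -- Long indices work
emb(idx.float())         # raises the "Expected ... 'indices' ... Long, Int" error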

I am not sure where I am going wrong. Why is “requires_grad” not being set by default?
Any help will be appreciated! Thanks in advance!

# -*- coding: utf-8 -*-

import time
import torch
from preprocessing.preprocess_images import preprocess_images

def train(trg_field, model, batch_size, iterator, optimizer, criterion, device, clip, write_file=False):
    model.train()
    epoch_loss = 0

    trg_seqs = open('logs/train_targets.txt', 'w')
    pred_seqs = open('logs/train_predicted.txt', 'w')

    for i, batch in enumerate(iterator):
        # initialize the hidden state
        #h = model.encoder.init_hidden(batch_size)

        # grab the image and preprocess it
        img_names = batch.id
        src = preprocess_images(img_names, 'data/images/')
        # src will be list of image tensors
        # need to pack them to create a single batch tensor
        src = torch.stack(src).to(device)

        #print('train_src_shape:  ', src.shape)

        # target mml
        trg = batch.mml.to(device)

        # setting gradients to zero
        optimizer.zero_grad()
        print(src.requires_grad)
        print(trg.requires_grad)
        output, pred = model(trg_field, src, trg, True, True, 0.5)


        # translating and storing trg and pred sequences in batches
        if write_file:
            batch_size = trg.shape[1]
            for idx in range(batch_size):
                trg_arr = [trg_field.vocab.itos[itrg] for itrg in trg[:,idx]]
                trg_seq = " ".join(trg_arr)
                trg_seqs.write(trg_seq + '\n')

                pred_arr = [trg_field.vocab.itos[ipred] for ipred in pred[:,idx].int()]
                pred_seq = " ".join(pred_arr)
                pred_seqs.write(pred_seq+'\n')

        #trg = [trg len, batch size]
        #output = [trg len, batch size, output dim]
        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)

        #trg = [(trg len - 1) * batch size]
        #output = [(trg len - 1) * batch size, output dim]
        loss = criterion(output, trg)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss/len(iterator)

Could you post the model definition and additionally a minimal, executable code snippet to reproduce this issue, please?

This won’t solve the issue. Once the computation graph is detached (which it seems to be now) creating a new tensor with requires_grad=True might get rid of the error message, but would create a new computation graph from this point and all previously used operations would still be detached.
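
A minimal sketch of what I mean, with a small nn.Linear standing in for your model:

import torch

model = torch.nn.Linear(2, 2)
x = torch.randn(1, 2)

out = model(x).detach()      # the computation graph is cut here
out.requires_grad_(True)     # silences the error message ...
out.sum().backward()
print(model.weight.grad)     # ... but the model never receives a gradient: prints None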

Thanks for the link! Could you post a minimal executable code snippet to execute the training using random data and to reproduce the issue, please?

Here is the code snippet of the model. It produces the same error with an image input and text output. Please let me know if you need anything else.

'''
CODE SNIPPET TO REPLICATE
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext
import random

class Encoder(nn.Module):

    def __init__(self, input_channel, hid_dim, dropout, device):
        super(Encoder, self).__init__()

        self.device = device
        self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(self.device)
        self.hid_dim = hid_dim
        self.conv_layer1 = nn.Conv2d(input_channel, 64, kernel_size=(3,3), stride=(1,1), padding =(1,1))
        self.conv_layer2 = nn.Conv2d(64, 128, kernel_size=(3,3), stride=(1,1), padding =(1,1))
        self.conv_layer3 = nn.Conv2d(128, 256, kernel_size=(3,3), stride=(1,1), padding =(1,1))
        self.conv_layer4 = nn.Conv2d(256, 256, kernel_size=(3,3), stride=(1,1), padding =(1,1))
        self.batch_norm1 = nn.BatchNorm2d(256)
        self.maxpool = nn.MaxPool2d(kernel_size=(2,2), stride=(2,2))
        self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.fc_hidden = nn.Linear(101120, self.hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # img = [batch, Cin, W, H]
        batch = src.shape[0]
        C_in = src.shape[1]

        # src = [batch, Cin, w, h]
        # layer 1
        src = self.conv_layer1(src)
        src = F.relu(src)
        src = self.maxpool(src)
        # layer 2
        src = self.maxpool(F.relu(self.conv_layer2(src)))
        # layer 3
        src = F.relu(self.batch_norm1(self.conv_layer3(src)))
        # layer 4
        enc_output = self.maxpool1(F.relu(self.conv_layer4(src)))
        # flatten everything after the batch dim
        enc_output = torch.flatten(enc_output, start_dim=1, end_dim=-1)  # [batch, 256*W'*H']
        enc_output = enc_output.unsqueeze(0)    # [1, batch, 256*W'*H']
        enc_output = self.fc_hidden(enc_output) # [1, batch, hid_dim]
        return enc_output


class Decoder(nn.Module):
    def __init__(self, emb_dim, hid_dim, output_dim, n_layers, dropout):#, attention):
        super(Decoder, self).__init__()

        self.emb = nn.Embedding(output_dim, emb_dim)
        self.hid_dim = hid_dim
        self.output_dim = output_dim

        self.lstm = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout, bidirectional=False)
        self.fc = nn.Linear(hid_dim, output_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, trg, hidden, cell):
        trg = trg.unsqueeze(0)  # [1, batch]
        embed = self.drop(self.emb(trg))  # [1, batch, emb_dim]
        lstm_input = embed
        lstm_output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))   # [1, batch, hid_dim]
        prediction = self.fc(lstm_output)  # [1, batch, output_dim]
        return prediction.squeeze(0), hidden, cell

class Img2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Img2Seq, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, trg_field, src, trg, write_flag=False, teacher_force_flag=False, teacher_forcing_ratio=0):
        # src = image of shape [batch, C_in, W, H]
        batch_size = src.shape[0]
        # run the encoder --> get flattened FV of images
        enc_output = self.encoder(src)
        print('enc_output shape: ', enc_output.shape)

        # running Decoder
        # trg = [trg_len, batch_size]
        trg_len = trg.shape[0]
        trg_dim = self.decoder.output_dim
        # to store all separate outputs of individual token
        outputs = torch.zeros(trg_len, batch_size, trg_dim).to(self.device) #[trg_len, batch, output_dim]
        # for each token, [batch, output_dim]

        src = trg[0,:]  # first decoder input: the <sos> tokens (the name src is reused here)
        hidden, cell = enc_output, enc_output  # [1, batch, hid_dim]

        if write_flag:
            pred_seq_per_batch = torch.zeros(trg.shape)
            init_idx = trg_field.vocab.stoi[trg_field.init_token]
            pred_seq_per_batch[0,:] = torch.full(src.shape, init_idx)

        for t in range(1, trg_len):

            output, hidden, cell = self.decoder(src, hidden, cell)
            top1 = output.argmax(1)     # [batch_size]

            if write_flag:
                pred_seq_per_batch[t,:] = top1
            # decide if teacher forcing should be used or not
            teacher_force = False
            if teacher_force_flag:
                teacher_force = random.random() < teacher_forcing_ratio

            src = trg[t] if teacher_force else top1

        if  write_flag: return outputs, pred_seq_per_batch
        else: return outputs

Could you add the shapes of the random input tensors as well as the model initialization?

After cropping, resizing, and padding, the final input image tensor shape for the encoder will be: [50,3,316,46] ([Batch, channel, W, H]).
Target input shape will be: [100, 50] ([seq_len, batch])

Please let me know if you want me to attach my preprocessing and other scripts too. Thank you very much for helping.

I’ve used this code to reproduce the issue:

model = Img2Seq(Encoder(3, 1, 0.5, 'cpu'), Decoder(1, 1, 1, 1, 0.5), 'cpu')
x = torch.randn(50, 3, 316, 46)
y = torch.zeros(100, 50).long()
out = model(None, x, y)

and I assume you are expecting out to be attached to the computation graph.
However, based on your code, out is initialized as torch.zeros:

outputs = torch.zeros(trg_len, batch_size, trg_dim).to(self.device) #[trg_len, batch, output_dim]

and afterwards just returned:

        if  write_flag: return outputs, pred_seq_per_batch
        else: return outputs

so it’s expected that you won’t be able to call backward on this newly initialized torch.zeros tensor.
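
You can reproduce the same error in isolation with arbitrary shapes:

import torch

criterion = torch.nn.CrossEntropyLoss()
outputs = torch.zeros(4, 10)            # plain tensor: requires_grad=False, no grad_fn
trg = torch.randint(0, 10, (4,))
loss = criterion(outputs, trg)
loss.backward()                         # RuntimeError: element 0 of tensors does not require grad ...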

Thank you for your response. I am getting the same error even if I initialize “torch.zeros” before calling the encoder. I am creating the zeros tensor to capture the batch output. Could you tell me what the best way to do this would be?

Where you initialize it wouldn’t matter, since you are currently not using outputs at all.
I assume you would like to assign some computed values to it at some point, but this code seems to be missing.
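
The usual seq2seq pattern is to store each decoder step in the preallocated tensor, e.g. something along these lines inside your loop (a sketch based on your posted forward, not tested against the full code):

for t in range(1, trg_len):
    output, hidden, cell = self.decoder(src, hidden, cell)  # output: [batch, output_dim]
    outputs[t] = output      # store this step so outputs stays attached to the graph
    top1 = output.argmax(1)  # [batch]

    if write_flag:
        pred_seq_per_batch[t, :] = top1

    teacher_force = teacher_force_flag and random.random() < teacher_forcing_ratio
    src = trg[t] if teacher_force else top1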

Thank you very much for helping me. Just confirming: if I print requires_grad for the target, output, and weights, it will be True for the output and weights, and False for the target, as we don’t want it to get updated. Am I correct?

Usually you wouldn’t update the targets, so yes, requires_grad would be False in this case.
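
A quick way to check this, again with a stand-in nn.Linear:

import torch

model = torch.nn.Linear(2, 3)
x = torch.randn(4, 2)
target = torch.randint(0, 3, (4,))

output = model(x)
print(model.weight.requires_grad)   # True  (trainable parameter)
print(output.requires_grad)         # True  (computed from parameters that require grad)
print(target.requires_grad)         # False (plain data tensor)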
