Loss requires_grad is False

Hello all,
I am running into this issue for the first time. The computed loss has requires_grad = False by default, but it should be True, and I have no idea why this is happening. Apart from that, even if I explicitly set requires_grad to True on the loss, the model parameters are still not getting updated. Please see the model and training code below. I have checked requires_grad of the model parameters and they are True.
Regards

import sys

import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class LSTMClassifier(nn.Module):
    def __init__(self):
        super(LSTMClassifier, self).__init__()
        self.lstm = nn.LSTM(1, 64, num_layers=1, batch_first=True)
        self.hidden2out = nn.Linear(64, 2)
        self.dropout_layer = nn.Dropout(p=0.3)

    def forward(self, data):
        outputs, (ht, ct) = self.lstm(data, None)
        output = self.dropout_layer(ht[-1])
        output = self.hidden2out(output)
        return output

model = LSTMClassifier().cuda()
def train(model,dataloaders,num_epochs,optimizer,patience = None):
    i = 0
    phase1 = dataloaders.keys()
    criterion = nn.CrossEntropyLoss().cuda()
    train_loader = dataloaders['train']
    if(torch.cuda.is_available()):
        device = 'cuda'
    else:
        device = 'cpu'
    if(patience!=None):
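        # EarlyStopping is a separate user-defined utility class (not shown here)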
        earlystop = EarlyStopping(patience = patience,verbose = True)
    for epoch in range(num_epochs):
        print('Epoch:',epoch)
        epoch_metrics = {"loss": [], "acc": []}
        for phase in phase1:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            for  batch_idx, (data, target) in enumerate(dataloaders[phase]):
                data, target = Variable(data), Variable(target)
                data = data.type(torch.FloatTensor).to(device)
                target = target.type(torch.LongTensor).to(device)
                optimizer.zero_grad()
                data.requires_grad = True
                output = model(data)
                output.requires_grad = True
                loss = criterion(output, target)
                loss.requires_grad = True
                acc = 100 * (output.detach().argmax(1) == target).cpu().numpy().mean()
                epoch_metrics["loss"].append(loss.item())
                epoch_metrics["acc"].append(acc)
                sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [Loss: %f (%f), Acc: %.2f%% (%.2f%%)]"
                % (
                    epoch,
                    num_epochs,
                    batch_idx,
                    len(dataloaders[phase]),
                    loss.item(),
                    np.mean(epoch_metrics["loss"]),
                    acc,
                    np.mean(epoch_metrics["acc"]),
                    )
                )

                if(phase =='train'):
                    loss.backward()
                    optimizer.step()
            epoch_acc = np.mean(epoch_metrics["acc"])
            epoch_loss = np.mean(epoch_metrics["loss"])
            if(phase == 'val' and patience !=None):
                earlystop(epoch_loss,model)
                if(earlystop.early_stop):
                    print("Early stopping")
                    model.load_state_dict(torch.load('./checkpoint.pt'))
                    print('{} Accuracy: {}'.format(phase,epoch_acc.item()))
                    break
        print('{} Accuracy: {}'.format(phase,epoch_acc.item()))

Your code works fine using dummy inputs and targets:

class LSTMClassifier(nn.Module):
    def __init__(self):
        super(LSTMClassifier, self).__init__()
        self.lstm = nn.LSTM(1, 64, num_layers=1, batch_first=True)
        self.hidden2out = nn.Linear(64, 2)
        self.dropout_layer = nn.Dropout(p=0.3)
    def forward(self, data):
        outputs, (ht, ct) = self.lstm(data, None)
        output = self.dropout_layer(ht[-1])
        output = self.hidden2out(output)
        return output

model = LSTMClassifier()

x = torch.randn(1, 1, 1)
target = torch.randint(0, 2, (1,))
criterion = nn.CrossEntropyLoss()

output = model(x)
loss = criterion(output, target)
loss.backward()
print(model.lstm.weight_ih_l0.grad.abs().sum())
> tensor(0.5159)

There are some minor issues in your code, which are either unrelated or would raise an error in newer PyTorch versions:

  • Variables are deprecated since PyTorch 0.4, so you can use tensors directly
  • setting the requires_grad attribute on non-leaf tensors should raise an error, so remove output.requires_grad = True and loss.requires_grad = True (see the sketch below)
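
A minimal, self-contained sketch of both points with toy tensors (the names are just for illustration):

import torch

# Plain tensors replace Variable: anything computed from a tensor that
# requires gradients will itself require gradients.
w = torch.randn(3, requires_grad=True)   # leaf tensor, similar to a model parameter
y = w * 2                                # non-leaf tensor, requires_grad is already True
print(y.requires_grad, y.grad_fn)        # True <MulBackward0 ...>

# Setting requires_grad on a non-leaf tensor raises an error in newer PyTorch versions.
try:
    y.requires_grad = True
except RuntimeError as e:
    print(e)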

Yes Sir, there is something wrong with my data, but it is normal tensor data. Can you tell me what could be wrong with it? I had to set requires_grad to True on the loss because the model was returning it with requires_grad = False.
Thanks for the quick reply.

The model in your code snippet returns an output which requires gradients. Is it different on your setup using your minimal code snippet?

It is the same. I have tried it with a different dataset and with random sequences as well, and in those cases the output requires gradients. But for the dataset I am applying it to, it returns False. The dataset is simple, with 1 feature and 16 temporal segments, shape (batch_size, 1, 16), and it is a binary classification dataset. Is there anything I should change in the dataset that could be creating this issue?

It shouldn't be necessary to change anything in the dataset, as the input shouldn't change the requires_grad flag of your output.
If the simple model gives you the right answer for synthetic data, I would suggest adding your data pipeline step by step and observing at which point the computation graph gets detached.
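
For example, a rough sketch of that kind of check, reusing model, dataloaders and device from the snippets above (the exact steps will depend on your pipeline):

# Grab one batch and inspect the graph after each stage of the forward pass.
data, target = next(iter(dataloaders['train']))
data = data.float().to(device)

out, (ht, ct) = model.lstm(data, None)
print(out.requires_grad, out.grad_fn)    # expect True and a *Backward node

out = model.dropout_layer(ht[-1])
print(out.requires_grad, out.grad_fn)

out = model.hidden2out(out)
print(out.requires_grad, out.grad_fn)    # False / None here means the graph broke earlier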

Ok sir, I have one last doubt. I have noticed that if we change the shape of any variable in the forward pass, it takes up a lot of GPU memory, so the solution I found online was to detach the variable, as in this example:


class FAN(nn.Module):
    def __init__(self):
        super(FAN, self).__init__()
        self.encoder = Encoder()

    def forward(self, x):
        batch_size, seq_length, c, h, w = x.shape
        x = x.view(batch_size * seq_length, c, h, w).detach()  # detach() cuts the graph here
        x = self.encoder(x)
        return x

and


class FAN(nn.Module):
    def __init__(self):
        super(FAN, self).__init__()
        self.encoder = Encoder()

    def forward(self, x):
        batch_size, seq_length, c, h, w = x.shape
        x = x.view(batch_size * seq_length, c, h, w)  # no detach(): the graph stays intact
        x = self.encoder(x)
        return x

Will the backpropagation graph be the same in both cases? Or is the first method incorrect, so that the weights will not get updated normally?
Thanks for such quick replies. Your help is appreciated.

detach() on a tensor will, as the name suggests, detach the tensor from the computation graph, so gradients will be stopped at this point. All operations performed before the detach call will not get any gradients from operations performed on this tensor.
If you are detaching manually somewhere in the model, this might explain why the output doesn't require gradients.
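
A tiny illustration of that behavior with toy tensors:

import torch

w = torch.randn(3, requires_grad=True)   # stands in for a model parameter
y = (w * 2).detach()                     # the graph is cut here
loss = y.sum()

print(loss.requires_grad, loss.grad_fn)  # False None -> loss.backward() would fail
                                         # and w would never receive a gradient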

The increase in memory is expected, as the backward pass needs the intermediate tensors to properly calculate the gradients.
If you are running out of memory, try lowering the batch size or use torch.utils.checkpoint to trade compute for memory.
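
A minimal sketch of torch.utils.checkpoint with a made-up toy encoder (module and shapes are only for illustration):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Activations inside `encoder` are recomputed during the backward pass instead of
# being stored, which lowers peak memory at the cost of extra compute.
encoder = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 128))
x = torch.randn(32, 128, requires_grad=True)

out = checkpoint(encoder, x, use_reentrant=False)  # drop use_reentrant on older PyTorch releases
out.sum().backward()
print(encoder[0].weight.grad.abs().sum())          # gradients still reach the encoder weights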


Thanks a lot sir, got it.

Hello,

I am running into the same problem when using a customized loss function. For instance:

def myloss(x):
    return torch.mean(x)

However, I noticed that the output of the LSTM doesn't have a grad, so could you please tell me what I should do in this case?

Thanks

Could you print the .grad_fn attribute of your model output and of your custom loss function output, please?
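
For example, something along these lines, using a toy LSTM and the custom loss from above (the modules in your model will differ):

import torch
import torch.nn as nn

def myloss(x):
    return torch.mean(x)

lstm = nn.LSTM(1, 64, batch_first=True)
x = torch.randn(4, 16, 1)

output, (ht, ct) = lstm(x)
loss = myloss(output)

print(output.grad_fn)  # expected: a *Backward node; None would mean the graph is detached
print(loss.grad_fn)    # expected: <MeanBackward0 ...>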