Accuracy, training, test loss doesn't change throughout all the epochs

I am working on a text classification problem with a binary output 0 or 1. The accuracy, train loss and test loss remains the same. The accuracy is exact the same for all the epochs. All the steps looks very correct. I had tried out several ways to figure out what is going wrong. Nothing actually worked. Please help.

class RNN(nn.Module):
    def __init__(self, num_layers, num_classes, input_size, hidden_size,vocab,dropout):
        super(RNN,self).__init__()
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.sequence_len = sequence_len
        self.embedding = nn.Embedding(len(vocab),input_size)
        nn.init.xavier_normal_(self.embedding.weight)
        self.rnn = nn.RNN(input_size, hidden_size,num_layers,dropout = dropout, nonlinearity = 'tanh', batch_first=True, bias = True, bidirectional = False)
        self.linear = nn.Linear(hidden_size, 1)
        nn.init.xavier_normal_(self.linear.weight)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self,x):
        lens = list(map(len, x))
        padded = pad_sequence(x, batch_first=True)
        output_embedding=self.embedding(padded)   

        packed = pack_padded_sequence(input = output_embedding,lengths = lens, batch_first=True, enforce_sorted=False)
        input_rnn = packed
        h011 = torch.zeros(1,32,3)
        output_11, hidden_11 = self.rnn(input_rnn,h011)
        output_padded, output_lengths = pad_packed_sequence(output_11, batch_first=True)
        final_output_11 = self.linear(hidden_11)        
        prob_11 = self.sigmoid(final_output_11)
        return output_padded, hidden_11, prob_11

 num_layers = 1
   num_classes = 2
   input_size = 5
   hidden_size = 3
   criterion = nn.BCELoss()
   sequence_len = 1
   dropout = 0.5
   rnn = RNN(num_layers, num_classes, input_size, hidden_size,vocab,dropout)
   epochs = 10
   lr = 0.01
   weight_decay=0.011
   def train_loop(model,criterion,optimizer,train_loader,valid_loader,epochs):
    train_losses= []
    valid_losses= []

    for epoch in range(epochs):
        train_loss=0
        for label,text in train_loader:
            
            output,hidden,prob = rnn.forward(text)
            prob = torch.tensor([item.item() for sublist in prob for item in sublist],dtype = torch.float32)
            label = torch.tensor(label, dtype = torch.float32)
            loss=criterion(prob,label)
            optimizer.zero_grad()
            loss.requires_grad = True
            
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss=train_loss/len(train_loader)
        
        
        valid_loss=0
        
        with torch.no_grad():
            correct=0
            total=0
            for label,text in valid_loader:
                output,hidden,prob = rnn.forward(text)
                prob = torch.tensor([item.item() for sublist in prob for item in sublist],dtype = torch.float32)
                label = torch.tensor(label, dtype = torch.float32)
                loss=criterion(prob,label)
                valid_loss += loss.item()
                p = torch.tensor([1 if i > 0.5 else 0 for i in prob.data], dtype = torch.float32)
                #predicted = torch.max(p, 1)
                
                total += label.size(0)
                
                correct += (p == label).sum().item()
            valid_loss=valid_loss/len(valid_loader)
            accuracy = 100 * correct / total
            print(accuracy)
        scheduler.step(accuracy)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        
        print(f'Epoch {epoch+1:<2d}/{epochs} --> Train Loss: {train_loss:.4f} |  Valid Loss: {valid_loss:.4f}')

   from torch.optim.lr_scheduler import StepLR
   from torch.optim.lr_scheduler import ReduceLROnPlateau
   optimizer = torch.optim.Adam(rnn.parameters(), lr=0.01)
   scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.4, patience=5, 
   verbose=True) # need to change scheduler.step

   train_loop(rnn,criterion,optimizer,train_loader,valid_loader,epochs)

Re-wrapping a tensor in another tensor as well as calling item() on it will detach it from the computation graph:

prob = torch.tensor([item.item() for sublist in prob for item in sublist],dtype = torch.float32)

and the subsequent backward() call would raise an issue.
However, since you’ve also assigned loss.requires_grad = True, the error is masked.

I’m unsure what shape prob has before and after, but you would need to use e.g. tensor.view in case you need to change its shape.

@ptrblck thank you very much for pointing it out.

prob = torch.tensor([item.item() for sublist in prob for item in sublist],dtype = torch.float32)

This above code didn’t let loss function calculate the gradients. Finally I used torch.flatten() to reshape the tensor.

But, now I have another question. Still my weights looks unupdated after calculating gradients and even after every epoch the weights remain the very same. I used this code to check it.

list(rnn.parameters())[0].clone()

Where I am I going wrong? Please check the screenshot.

To make sure the computation graph is not detached anymore you could check the .grad attribute of model parameters before and after the first loss.backward() call.
Before it, .grad should return None (as no gradients were computed yet), and afterwards it should show a valid tensor.