Model gives the same output for every input

Hi, I was trying to fine-tune Sentence-BERT (writing the fine-tuning code from scratch), and after training I noticed that the model produces the same output for every input pair. I have tried tuning the hyperparameters, using a scheduler to decrease the learning rate as the loss approaches a minimum, and several different optimizers. In every case the loss keeps fluctuating in the same range, and after a few epochs the model gives the same output for all inputs. Even before one epoch of training has completed, the model returns the same cosine-similarity score of 1 for every input.
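For context, this is roughly how the collapse shows up when I check the embeddings after training (a minimal sketch; the two example sentences and the use of torch.nn.functional.cosine_similarity on the pooled outputs are just for illustration):

import torch.nn.functional as F

net.eval()
with torch.no_grad():
    emb_a = net.bert(["rishabh is good"])        # pooled embedding, shape (1, 768)
    emb_b = net.bert(["the sky is blue today"])  # a completely unrelated sentence
    sim = F.cosine_similarity(emb_a, emb_b)      # comes out ~1.0 for every pair I try
print(sim.item())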

I run into this problem with almost every model I train. Could someone help me understand why this happens? I have mostly observed this behavior in cases where I save the model's state with torch.save().
@ptrblck Can you please check this once?
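In case it is relevant, this is roughly how I load the checkpoint back afterwards (a minimal sketch, assuming the same dict keys used in the torch.save() call at the end of the training script below):

checkpoint = torch.load("fine.pt", map_location=device)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
loss_values = checkpoint['loss']
net.eval()  # switch to eval mode before computing similarities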

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class BERT(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertModel.from_pretrained('bert-base-uncased').to(device)

    def forward(self, sent):
        # `sent` is a batch (list) of raw sentences, e.g. ["rishabh is good", "he is bad", "alright!"]
        encoded = self.tokenizer(sent, padding=True, truncation=True, return_tensors="pt").to(device)
        outputs = self.model(**encoded)
        last_hidden_states = outputs[0]  # (batch, seq_len, 768)
        # mean-pool over the token dimension to get one fixed-size vector per sentence
        pooled = torch.sum(last_hidden_states, dim=-2) / last_hidden_states.shape[-2]

        return pooled  # (batch, 768)


class Siamese(nn.Module):
    def __init__(self):
        super().__init__()
        # classifier over the concatenation [u, v, |u - v|], hence 3 * 768 input features
        self.out = nn.Linear(3 * 768, 2)

    def forward(self, X1, X2):
        res = torch.cat([X1, X2, torch.abs(X1 - X2)], dim=-1)
        res = self.out(res)
        return res

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BERT()
        self.sia = Siamese()

    def forward(self, X1, X2):
        X1 = self.bert(X1)
        X2 = self.bert(X2)
        res = self.sia(X1,X2)
        return res


net = Network().to(device)

criterion = nn.CrossEntropyLoss().to(device)
learning_rate = 0.0001
optimizer = optim.Adam(params=net.parameters(), lr=learning_rate, weight_decay=0, betas=(0.9,0.99))
data = dataset()  # custom Dataset yielding (sentence1, sentence2, label) triples
batch_size = 8
loader = DataLoader(data, batch_size=batch_size)

# scheduler = StepLR(optimizer, step_size=1, gamma=0.1)  # also tried decaying the LR; same behavior
loss_values = []
for epoch in range(2):
    # scheduler.step()
    running_loss = 0.0
    for batch in loader:
        s1, s2, label = batch[0], batch[1], batch[2]
        label = torch.tensor(list(map(int, label))).long().to(device)
        optimizer.zero_grad()
        output = net(s1, s2)
        loss = criterion(output, label)
        print("LOSS:", loss.item())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # average the loss over all batches in the epoch (not over batch_size)
    loss_values.append(running_loss / len(loader))
PATH = "fine.pt"
torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss_values,
            }, PATH)
plt.xlabel("epochs")
plt.ylabel("loss")
plt.plot(loss_values,"b")
plt.show()

print('Finished Training')