Loss function gets stuck at some epochs

Hi, I am using a contrastive loss on a set of 3D medical images. The loss takes two augmented views of an image and computes a similarity score. The problem is that when I train on 3D images, the loss barely changes and is unstable: sometimes it decreases and sometimes it increases, although the fluctuations are small. I expected it to decrease for some epochs and then level off. I tried the same setup on 2D images and that is exactly what I saw: the loss decreased and then flattened out after a while, which, as far as I understand, is the normal behavior for a contrastive loss. So I am wondering why, in the 3D setting, the loss just bounces around a value instead of behaving the way it does in 2D.

One reason could be the batch size. Contrastive losses need large batch sizes, but with 3D medical images I cannot go above a batch size of 4. I tried batch sizes of 2 and 4, and after increasing the batch size the loss actually went up, whereas in 2D the loss dropped significantly when I doubled the batch size. So I do not understand this strange behavior of the loss on 3D medical images. Another problem is that the loss sometimes gets stuck at a certain epoch and never gets past it; this happens from time to time, not regularly.
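To put a number on the batch-size point: with the NT-Xent loss I use below, every view is contrasted against 2 * batch_size - 2 in-batch negatives, so the pool of negatives collapses at the batch sizes I can afford in 3D:

```
# negatives per view in NT-Xent with in-batch negatives only
for batch_size in (512, 4, 2):
    print(f"batch_size={batch_size}: {2 * batch_size - 2} negatives per view")
# batch_size=512: 1022 negatives per view
# batch_size=4: 6 negatives per view
# batch_size=2: 2 negatives per view
```

I do not know whether such a small pool of negatives alone explains the instability, though.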
This is my code in 2D:

```
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import ShuffleSplit

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])])
class PairwiseAg(object):
    def __init__(self, transform):
        self.transform = transform
    def __call__(self, vol): 
        voli = self.transform(vol)
        volj = self.transform(vol)
        return(voli,volj)
dataset =datasets.CIFAR10('./data', train=True, transform=PairwiseAg(transform), download=True)
ss = ShuffleSplit(n_splits=1, test_size= 0.2,random_state=0)
for train_idx, val_idx in ss.split(dataset):
    train_idx = train_idx
    val_idx = val_idx
train_set = Subset(dataset,train_idx)
val_set = Subset(dataset,val_idx)
train_loader = DataLoader(train_set,batch_size=512,shuffle=True,num_workers=2,drop_last=False)
val_loader = DataLoader(val_set,batch_size=512,shuffle=True,num_workers=2,drop_last=False)
device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')
model = torchvision.models.resnet18(pretrained=False, progress=True)
model.fc = nn.Linear(in_features=512,out_features=128)
model = model.to(device)
class NT_Xent(nn.Module):
    def __init__(self, temperature, device):
        super(NT_Xent, self).__init__()
        self.temperature = temperature
        self.device = device
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)
    def forward(self, z_i, z_j):
        self.batch_size= z_i.size()[0]
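        # build a boolean mask that keeps only negative pairs: the main diagonal
        # (self-similarity) and the two positive diagonals at +/- batch_size
        # are set to False below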
        self.mask = torch.ones((self.batch_size * 2, self.batch_size * 2), dtype=bool)
        self.mask = self.mask.fill_diagonal_(0)
        for i in range(self.batch_size):
            self.mask[i, self.batch_size + i] = 0
            self.mask[self.batch_size + i, i] = 0
        z_i= F.normalize(z_i, dim=1)
        z_j= F.normalize(z_j, dim=1)
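        # stack both views into one (2*batch_size, dim) batch and compute the full
        # pairwise cosine-similarity matrix scaled by the temperature; the positives
        # for each view sit on the diagonals offset by +/- batch_size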
        p1 = torch.cat((z_i, z_j), dim=0)
        sim = self.similarity_f(p1.unsqueeze(1), p1.unsqueeze(0)) / self.temperature
        sim_i_j = torch.diag(sim, self.batch_size)
        sim_j_i = torch.diag(sim, -self.batch_size)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(self.batch_size * 2,1)
        negative_samples = sim[self.mask].reshape(self.batch_size * 2, -1)
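        # column 0 of the concatenated logits holds the positive similarity, so the
        # cross-entropy target is index 0 for every one of the 2*batch_size rows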
        labels = torch.zeros(self.batch_size * 2).to(positive_samples.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= 2 * self.batch_size
        return(loss)
optimizer = optim.Adam(model.parameters(),lr=5e-1)
criterion = NT_Xent(0.7,device)
def train(epoch):
    model.train()
    total_loss = 0
    for i,(X,_) in enumerate(train_loader):
        X1 = X[0].to(device)
        X2 = X[1].to(device)
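        # embed both augmented views with the shared encoder; NT-Xent pulls the
        # matching views together and pushes the other 2*batch_size - 2 views away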
        h1 = model(X1)
        h2 = model(X2)
        loss = criterion(h1.float(),h2.float())
        total_loss+=loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return(total_loss/len(train_loader.dataset))

def val():
    model.eval()
    total_loss = 0
    # no gradients are needed for validation, so skip building the graph
    with torch.no_grad():
        for i, (X, _) in enumerate(val_loader):
            X1 = X[0].to(device)
            X2 = X[1].to(device)
            h1 = model(X1)
            h2 = model(X2)
            loss = criterion(h1.float(), h2.float())
            total_loss += loss.item()
    return total_loss / len(val_loader.dataset)

def main(num_epochs):
    for epoch in range(num_epochs):
        tr_loss = train(epoch)
        val_loss = val()
        print(f'Epoch: {epoch} train_loss:{tr_loss:.4f}  val_loss:{val_loss:.4f}')
```

output:
```
curren_lr:0.5
Epoch: 0 train_loss:0.0131  val_loss:0.0131
Epoch: 1 train_loss:0.0127  val_loss:0.0129
Epoch: 2 train_loss:0.0126  val_loss:0.0129
Epoch: 3 train_loss:0.0126  val_loss:0.0129
Epoch: 4 train_loss:0.0125  val_loss:0.0129
Epoch: 5 train_loss:0.0124  val_loss:0.0128
Epoch: 6 train_loss:0.0124  val_loss:0.0126
Epoch: 7 train_loss:0.0123  val_loss:0.0127
Epoch: 8 train_loss:0.0123  val_loss:0.0125
Epoch: 9 train_loss:0.0123  val_loss:0.0127
Epoch: 10 train_loss:0.0123  val_loss:0.0126
Epoch: 11 train_loss:0.0122  val_loss:0.0125
```

And these are my results in 3D:
```
Epoch:67 | Train_Loss:2.5502 | Val_Loss:4.0081
Validation loss does not decrease from 2.5172, checks_without_progress:24
Epoch: 68/100
lr = 0.00000100
Epoch:68 | Train_Loss:2.5834 | Val_Loss:3.0598
Validation loss does not decrease from 2.5172, checks_without_progress:25
Epoch: 69/100
lr = 0.00000100
Epoch:69 | Train_Loss:2.6317 | Val_Loss:3.0419
Validation loss does not decrease from 2.5172, checks_without_progress:26
Epoch: 70/100
lr = 0.00000100
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0
Epoch:70 | Train_Loss:2.6459 | Val_Loss:3.8916
Validation loss does not decrease from 2.5172, checks_without_progress:27
Epoch: 71/100
lr = 0.00000100
Epoch:71 | Train_Loss:2.6370 | Val_Loss:2.8522
Validation loss does not decrease from 2.5172, checks_without_progress:28
Epoch: 72/100
lr = 0.00000100
Epoch:72 | Train_Loss:2.6422 | Val_Loss:3.1595
Validation loss does not decrease from 2.5172, checks_without_progress:29
Epoch: 73/100
lr = 0.00000100
Epoch:73 | Train_Loss:2.5912 | Val_Loss:3.2604
Validation loss does not decrease from 2.5172, checks_without_progress:30
Epoch: 74/100
lr = 0.00000100
Epoch:74 | Train_Loss:2.6797 | Val_Loss:4.2577
Validation loss does not decrease from 2.5172, checks_without_progress:31
Epoch: 75/100
lr = 0.00000100
Epoch:75 | Train_Loss:2.6413 | Val_Loss:3.3625
Validation loss does not decrease from 2.5172, checks_without_progress:32
Epoch: 76/100
lr = 0.00000100
Epoch:76 | Train_Loss:2.5942 | Val_Loss:3.9586
Validation loss does not decrease from 2.5172, checks_without_progress:33
Epoch: 77/100
lr = 0.00000100
Epoch:77 | Train_Loss:2.6808 | Val_Loss:3.6534
Validation loss does not decrease from 2.5172, checks_without_progress:34
Epoch: 78/100
lr = 0.00000100
Epoch:78 | Train_Loss:2.6145 | Val_Loss:3.7018
Validation loss does not decrease from 2.5172, checks_without_progress:35
Epoch: 79/100
lr = 0.00000100
Epoch:79 | Train_Loss:2.6323 | Val_Loss:3.4321
Validation loss does not decrease from 2.5172, checks_without_progress:36
Epoch: 80/100
lr = 0.00000100
Epoch:80 | Train_Loss:2.6303 | Val_Loss:3.9152
Validation loss does not decrease from 2.5172, checks_without_progress:37
Epoch: 81/100
lr = 0.00000100
Epoch:81 | Train_Loss:2.6353 | Val_Loss:3.1611
Validation loss does not decrease from 2.5172, checks_without_progress:38
Epoch: 82/100
lr = 0.00000100
Epoch:82 | Train_Loss:2.5255 | Val_Loss:3.4790
Validation loss does not decrease from 2.5172, checks_without_progress:39
Epoch: 83/100
lr = 0.00000100
Epoch:83 | Train_Loss:2.5346 | Val_Loss:3.3782
Validation loss does not decrease from 2.5172, checks_without_progress:40
Epoch: 84/100
lr = 0.00000100
Epoch:84 | Train_Loss:2.6319 | Val_Loss:4.0758
Validation loss does not decrease from 2.5172, checks_without_progress:41
Epoch: 85/100
lr = 0.00000100
Epoch:85 | Train_Loss:2.7081 | Val_Loss:3.8005
Validation loss does not decrease from 2.5172, checks_without_progress:42
Epoch: 86/100
lr = 0.00000100
Epoch:86 | Train_Loss:2.5988 | Val_Loss:3.1003
Validation loss does not decrease from 2.5172, checks_without_progress:43
Epoch: 87/100
lr = 0.00000100
Epoch:87 | Train_Loss:2.5835 | Val_Loss:3.0547
Validation loss does not decrease from 2.5172, checks_without_progress:44
Epoch: 88/100
lr = 0.00000100
Epoch:88 | Train_Loss:2.5992 | Val_Loss:4.6173
Validation loss does not decrease from 2.5172, checks_without_progress:45
Epoch: 89/100
lr = 0.00000100
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0
Epoch:89 | Train_Loss:2.6231 | Val_Loss:3.9014
Validation loss does not decrease from 2.5172, checks_without_progress:46
Epoch: 90/100
lr = 0.00000100
Epoch:90 | Train_Loss:2.5815 | Val_Loss:3.4956
Validation loss does not decrease from 2.5172, checks_without_progress:47
Epoch: 91/100
lr = 0.00000010
```