Pytorch Triplet Loss Model can't predict class label for MNIST dataset though loss value is decreasing

I am trying to predict the labels of MNIST datatset using triplet loss. When I am training the model the loss value is gradually decreasing which indicates that the model is learning. But when I am trying to predict the label even with training data, the accuracy is very low, max 10%. Can anyone help me on how to predict the label if I train my model using triplet loss? I have also used KMeans Clustering on the embedding data but got no luck.

I converted the MNIST dataset for 50*50 pixels, the dataset is balanced. The inputs are images and
dataMNIST.csv contains imageName and corresponding labels.

import os

import torch

import torchvision

from torchvision.datasets import ImageFolder

from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

import numpy as np

import pandas as pd

from PIL import Image

import json

from torchvision.transforms import transforms

import random

import torch

import torchvision

from torch.utils.data import DataLoader

from torchvision.datasets import ImageFolder

from torchvision.transforms import transforms

from torchvision.models import resnet18

import torch.nn as nn

import torch.nn.functional as F

import zipfile

import matplotlib.pyplot as plt

import numpy as np

from tqdm.notebook import tqdm

import torch.optim as optim

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

from xgboost import XGBClassifier

from torchvision import transforms

from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score

from sklearn.cluster import KMeans

from numpy import arange


PATH = "/content/train_data/"

torch.manual_seed(2020)

np.random.seed(2020)

random.seed(2020)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":

torch.cuda.get_device_name()

embedding_dims = 10

batch_size = 32

epochs = 60

data = pd.read_csv(PATH+"dataMNIST.csv")

train_df, test_df = train_test_split(data, test_size=0.10, train_size=0.90)

train_df, validation_df = train_test_split(train_df, test_size=0.1, train_size=0.9)

train_df= train_df.reset_index(drop= True)

test_df = test_df.reset_index(drop= True)

validation_df = validation_df.reset_index(drop= True)

test_label = test_df.iloc[:, 1].values

class MNIST(Dataset):
    def __init__(self, df, path, train=True, transform=None):
        self.data_csv = df
        self.is_train = train
        self.transform = transform
        self.path = path
#         self.to_pil = transforms.ToPILImage()

        if self.is_train:
            self.images = df.iloc[:, 0].values
            self.labels = df.iloc[:, 1].values
            self.index = df.index.values
        else:
            self.images = df.iloc[:, 0].values

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):

        anchor_image_name = self.images[item]
        anchor_idx = anchor_image_name.find(":")
        anchor_image_path = " "
        # anchor_image_path = self.path + anchor_image_name.split(":")[1].split("_")[0]+ '/' + anchor_image_name
        # anchor_image_path = self.path +anchor_image_name.split("_")[1]+ '/' + anchor_image_name
        if anchor_idx > 0:
          anchor_image_path = self.path +anchor_image_name.split(":")[1].split("_")[0]+ '/' + anchor_image_name
        else:
          anchor_image_path = self.path +anchor_image_name.split("_")[1]+ '/' + anchor_image_name
        ##### Anchor Image #######
        anchor_img = Image.open(anchor_image_path).convert('1')
        if self.is_train:
            anchor_label = self.labels[item]
            positive_list = self.index[self.index!=item][self.labels[self.index!=item]==anchor_label]
            positive_item = random.choice(positive_list)
            positive_image_name = self.images[positive_item]
            positive_idx = positive_image_name.find(":")
            positive_image_path = " "
            if positive_idx > 0:
              positive_image_path = self.path +positive_image_name.split(":")[1].split("_")[0]+ '/' + positive_image_name
            else:
              positive_image_path = self.path +positive_image_name.split("_")[1]+ '/' + positive_image_name
            # positive_image_path = self.path +positive_image_name.split(":")[1].split("_")[0]+ '/'  + positive_image_name
            # positive_image_path = self.path +positive_image_name.split("_")[1]+ '/' + positive_image_name
            positive_img = Image.open(positive_image_path).convert('1')
            #positive_img = self.images[positive_item].reshape(28, 28, 1)
            negative_list = self.index[self.index!=item][self.labels[self.index!=item]!=anchor_label]
            negative_item = random.choice(negative_list)
            negative_image_name = self.images[negative_item]
            negative_idx = negative_image_name.find(":")
            negative_image_path = " "
            if negative_idx > 0:
              negative_image_path = self.path +negative_image_name.split(":")[1].split("_")[0]+ '/' + negative_image_name
            else:
              negative_image_path = self.path +negative_image_name.split("_")[1]+ '/' + negative_image_name

            # negative_image_path = self.path +negative_image_name.split("_")[1]+ '/' + negative_image_name
            # negative_image_path = self.path +negative_image_name.split(":")[1].split("_")[0]+ '/'  + negative_image_name
            negative_img = Image.open(negative_image_path).convert('1')
            #negative_img = self.images[negative_item].reshape(28, 28, 1)
            if self.transform!=None:
                anchor_img = self.transform(anchor_img)
                positive_img = self.transform(positive_img)
                negative_img = self.transform(negative_img)
            return anchor_img, positive_img, negative_img, anchor_label
        else:
            if self.transform:
                anchor_img = self.transform(anchor_img)
            return anchor_img

train_ds = MNIST(train_df, PATH,
                 train=True,
                 transform=transforms.Compose([
                     transforms.ToTensor()
                 ]))
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4,pin_memory=True)

validation_ds = MNIST(validation_df,PATH, train=False, transform=transforms.ToTensor())

validation_loader = DataLoader(validation_ds, batch_size=batch_size, shuffle=False, num_workers=4,pin_memory=True)

test_ds = MNIST(test_df,PATH, train=False, transform=transforms.ToTensor())

test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4,pin_memory=True)


class Network(nn.Module):
    def __init__(self, emb_dim=10):
        super(Network, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, 3),
            nn.PReLU(),
            nn.MaxPool2d(2, stride=2),
            #nn.Dropout(0.3),
            nn.Conv2d(8, 16, 3),
            nn.PReLU(),
            nn.MaxPool2d(2, stride=2),
            #nn.Dropout(0.3)
            nn.Conv2d(16, 32, 3),
            nn.PReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        self.fc = nn.Sequential(
            nn.Linear(32*4*4, 128),
            nn.PReLU(),
            nn.Linear(128, 64),
            nn.PReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.conv(x)
        # print(x.shape)
        x = x.view(-1, 32*4*4)
        x = self.fc(x)
        return x


def init_weights(m):
    if isinstance(m, nn.Conv2d):
        torch.nn.init.kaiming_normal_(m.weight)

model = Network(embedding_dims)

model.apply(init_weights)

model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

criterion = nn.TripletMarginLoss(margin=1.0, p=2)


model.train()
min = 1.0
max = 0.0
loss_all = []
for epoch in tqdm(range(epochs), desc="Epochs"):
    running_loss = []
    for step, (anchor_img, positive_img, negative_img, anchor_label) in enumerate(tqdm(train_loader, desc="Training", leave=False)):
        anchor_img = anchor_img.to(device)
        positive_img = positive_img.to(device)
        negative_img = negative_img.to(device)

        optimizer.zero_grad()
        anchor_out = model(anchor_img)
        positive_out = model(positive_img)
        negative_out = model(negative_img)

        loss = criterion(anchor_out, positive_out, negative_out)
        loss.backward()
        optimizer.step()

        running_loss.append(loss.cpu().detach().numpy())
    print("Epoch: {}/{} - Loss: {:.4f}".format(epoch+1, epochs, np.mean(running_loss)))
    loss_all.append(np.mean(running_loss))

    validation_results = []
    validation_labels = []
    total_correct = 0
    total_instances = 0

    model.eval()
    with torch.no_grad():
        for img, _, _, label in tqdm(train_loader):
            output = model(img.to(device))
            # np.append(validation_results, output.cpu().numpy())
            validation_results.append(output.cpu().numpy())
            validation_labels.append(label)
            tq = label.to(device)
            # print(model(img.to(device)).cpu().numpy().shape)
            predictions = output.argmax(dim=1, keepdim = True)
            # _, predictions = torch.min(model(img.to(device)).data, 1)
            # print(predictions)
            # correct_predictions = sum(predictions==tq).item()
            # total_correct+=correct_predictions
            total_correct += predictions.eq(tq.view_as(predictions)).sum().item()
            total_instances+=len(img)
    accuracy = (total_correct/len(train_loader.dataset))*100


    # validation_results = np.concatenate(validation_results)
    # validation_labels = np.concatenate(validation_labels)
    # kmeans = KMeans(n_clusters=len(np.unique(validation_labels))).fit(validation_results)
    # pred_labels = kmeans.predict(validation_results)
    # accuracy = accuracy_score(validation_labels, pred_labels) * 100


    print("Correct: ", total_correct, " Total: ", len(train_loader.dataset))
    print("Accuracy: {:.3F}".format(accuracy))


    if max< accuracy:
      max = accuracy
      torch.save({"model_state_dict2": model.state_dict(),
            "optimzier_state_dict2": optimizer.state_dict()
           }, PATH+"trained_model_curves_new2.pt")

      sm = torch.jit.script(model)
      sm.save(PATH+"trained_model_curves_new2_cpp.pt")

    if min>=np.mean(running_loss):
      min = np.mean(running_loss)
      torch.save({"model_state_dict": model.state_dict(),
            "optimzier_state_dict": optimizer.state_dict()
           }, PATH+"trained_model_curves_new.pt")

      sm = torch.jit.script(model)
      sm.save(PATH+"trained_model_curves_new_cpp.pt")


print(min)
print(max)