Ground-truth label normalization

The model is not necessarily wrong.
As a quick test, you could increase the weight of a specific class by a large amount and check whether your model is then able to predict this class.
E.g. increase the weight for class1 to 100 and recheck the output.
If your model still only predicts class0, there might be another issue in the code.
Could you post the model definition in this case, so that we can have a look?

When I tried weights = torch.tensor([0.0,1.0,0.0,0.0]).to(device),
then weights = torch.tensor([0.0,0.0,1.0,0.0]).to(device),
then weights = torch.tensor([0.0,0.0,0.0,1.0]).to(device),

it gave me the same output, [3], and the predicted figure is always background.

Here is my code:

import nibabel as nib
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from skimage.transform import rescale, resize
from torch.utils.data import Dataset, DataLoader

from torch import nn
import torch.optim as optim 
from sklearn.model_selection import train_test_split
from torchsummary import summary

def load_nii(img_path):
    """
    Load a 'nii' or 'nii.gz' file with nibabel.

    Parameters
    ----------
    img_path: string
        Path of the 'nii' or 'nii.gz' image file.

    Returns
    -------
    The image object returned by ``nib.load`` (a single object, not a
    tuple). Callers read the voxel values with
    ``np.asanyarray(img.dataobj)`` and the header with ``img.header``;
    together with ``img.affine`` these are what is needed to save another
    'nii' / 'nii.gz' image in the same space.
    """
    nimg = nib.load(img_path)
    return nimg

def Preprocessing(data):
    """Min-max normalise *data* into the range [0, 1].

    A constant input (max == min) would otherwise produce 0/0 = NaN for
    every element; it is mapped to all zeros instead.
    """
    lo = data.min()
    span = data.max() - lo
    if span == 0:
        # constant image: (data - lo) is all zeros, so dividing by 1 avoids NaNs
        span = 1
    return (data - lo) / span

def _center_slices(src, dst):
    """Return (source slice, destination slice) that centre a length-*src* axis in a length-*dst* axis."""
    if src <= dst:
        # pad: copy the whole source axis into the middle of the destination
        offset = (dst - src) // 2
        return slice(0, src), slice(offset, offset + src)
    # crop: take the central *dst* elements of the source axis
    offset = (src - dst) // 2
    return slice(offset, offset + dst), slice(0, dst)


def Resize(image, width, height):
    """Centre a 2-D *image* on a zero canvas of shape (width, height).

    Smaller images are zero-padded symmetrically (extra pixel on the
    bottom/right when the difference is odd). Larger images are now
    centre-cropped; previously the negative offsets produced wrong slices
    and a shape-mismatch error.

    Returns a new float64 array of shape (width, height).
    """
    W, H = image.shape
    src_x, dst_x = _center_slices(W, width)
    src_y, dst_y = _center_slices(H, height)
    img_crop = np.zeros(shape=(width, height))
    img_crop[dst_x, dst_y] = image[src_x, src_y]
    return img_crop

  
def Rescale(image, header, pix_scal_gol):
    """Resample a 2-D slice by integer per-axis factors taken from the header.

    Each factor is ``round(pix_scal_gol / spacing)``, where the spacings come
    from ``header.get_zooms()`` (exactly three values are unpacked). The
    slice is then resized to (rows * factor_x, cols * factor_y).
    NOTE(review): because the factor is rounded to an integer, it becomes 0
    when a spacing is at least twice ``pix_scal_gol`` — confirm this is intended.
    """
    n_rows, n_cols = image.shape
    zoom_x, zoom_y, _unused_z = header.get_zooms()
    factor_x = round(pix_scal_gol / zoom_x)
    factor_y = round(pix_scal_gol / zoom_y)
    target_shape = [n_rows * factor_x, n_cols * factor_y]
    return resize(image, target_shape)


def image_imshow(img_, img_seg):
    """Show an image next to a second map (e.g. its ground truth or a prediction).

    ``plt.subplots`` creates its own figure, so no separate ``plt.figure()``
    is opened first; that call used to leave an extra empty figure behind
    on every invocation. Returns ``(fig, ax)`` so callers can save or close it.
    """
    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(img_)
    ax[1].imshow(img_seg)
    return fig, ax
    
def _imshow(img_, img_gt, img_seg):
    """Show an image, its ground truth and a segmentation side by side.

    ``plt.subplots`` creates its own figure, so no separate ``plt.figure()``
    is opened first; that call used to leave an extra empty figure behind
    on every invocation. Returns ``(fig, ax)`` so callers can save or close it.
    """
    fig, ax = plt.subplots(1, 3)
    ax[0].imshow(img_)
    ax[1].imshow(img_gt)
    ax[2].imshow(img_seg)
    return fig, ax
    


def im_imshow(img):
    """Open a new figure and draw *img* in it (display only, returns None)."""
    plt.figure()
    return_value = plt.imshow(img)
    del return_value
    
    
    
def Hist(img):
    """Plot a 256-bin histogram of every value in *img*, drawn in black."""
    flat_values = img.ravel()
    plt.hist(flat_values, bins=256, fc='k', ec='k')

def deviser(data, num):
    """Split the sequence *data* into *num* contiguous chunks of near-equal length.

    Chunk ``k`` is ``data[k*n//num:(k+1)*n//num]`` with ``n = len(data)``, so
    chunk sizes differ by at most one and their concatenation is *data*.
    Integer boundaries replace the previous running float offset, which
    could end fractionally below ``len(data)`` and emit one chunk too many.
    An empty *data* returns ``[]``.

    Raises
    ------
    ValueError
        If *num* is smaller than 1 (previously ZeroDivisionError for 0 and
        an endless loop for negative values).
    """
    if num < 1:
        raise ValueError("num must be a positive integer, got %r" % (num,))
    n = len(data)
    if n == 0:
        return []
    return [data[k * n // num:(k + 1) * n // num] for k in range(num)]

def conversion(prediction, th=0.01):
    """Binarise an array/tensor: elements >= *th* become 1, all others 0.

    The result keeps the dtype of *prediction* and is a new object; the
    input is left untouched. (``prediction[:]`` is a view for both NumPy
    arrays and torch tensors, so the previous version overwrote its argument.)
    The mask is computed once, before any assignment; the old two-pass
    version turned every element into 1 when ``th <= 0``.
    """
    above = prediction >= th
    # torch tensors copy with .clone(); NumPy arrays with .copy()
    result = prediction.clone() if hasattr(prediction, "clone") else prediction.copy()
    result[~above] = 0
    result[above] = 1
    return result
           
def dice_coeff(prediction, label, smooth=1.0):
    """Smoothed Dice coefficient between two (binary) masks.

    ``dice = (2 * |P ∩ L| + smooth) / (|P| + |L| + smooth)``. A value near 1
    means the prediction overlaps the label well; near 0 means poor overlap.

    Both tensors are flattened to ``(prediction.size(0), -1)`` with
    ``reshape``; the previous ``view`` raised on non-contiguous inputs
    (e.g. transposed or permuted slices). The overlap is pooled over all
    elements, so this is one foreground-vs-background score. For
    multi-class segmentation, compute it per class instead.

    Parameters
    ----------
    prediction, label : torch.Tensor
        Masks with the same number of elements.
    smooth : float
        Additive smoothing; avoids 0/0 for two empty masks (default 1.0).

    Returns
    -------
    0-d tensor with the Dice score.
    """
    num = prediction.size(0)
    m1 = prediction.reshape(num, -1)
    m2 = label.reshape(num, -1)
    intersection = (m1 * m2).sum()
    return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)

data = {}
data_ed = []
headerr = []
header_gt = []
gt_ed = []

ROOT_DIR = '/home/iadi.lan/mtir/Bureau/training/'

for idx in range(1, 101):
    # load each patient's files and keep them in memory
    patient_dir = os.path.join(ROOT_DIR, 'patient%03.0d' % idx)
    D = sorted(os.listdir(patient_dir))
    # NOTE(review): relies on the sorted listing placing the files at these
    # positions for every patient — confirm.
    data[idx] = {
        'ed_data': D[2],
        'ed_gt': D[3],
        'es_data': D[4],
        'es_gt': D[5],
    }
    # Each file is loaded once; data and header are taken from the same object
    # (the previous version called load_nii twice per file).
    # NOTE(review): data_ed / gt_ed are filled from D[4] / D[5], i.e. the files
    # recorded above as 'es_data' / 'es_gt' — confirm ES (not ED) is intended.
    img_nii = load_nii(os.path.join(patient_dir, D[4]))
    gt_nii = load_nii(os.path.join(patient_dir, D[5]))
    data_ed.append(np.asanyarray(img_nii.dataobj))
    headerr.append(img_nii.header)
    gt_ed.append(np.asanyarray(gt_nii.dataobj))
    header_gt.append(gt_nii.header)
    


data_train, data_train_gt = [], []
data_test, data_test_gt = [], []

# Five contiguous groups of patients; inside each group the last 20 % of the
# volumes (no shuffling) go to the test set, the rest to the training set.
Data = deviser(data_ed, 5)
G_truth = deviser(gt_ed, 5)

for i in range(5):
    X_t, X_ts, y_t, y_ts = train_test_split(
        Data[i], G_truth[i], test_size=0.2, shuffle=False
    )
    data_train.extend(X_t)
    data_test.extend(X_ts)
    data_train_gt.extend(y_t)
    data_test_gt.extend(y_ts)


donne = []
size = []
size_test = []
taille = []
x = {}
y = {}
z = {}
x_test = {}
y_test = {}
z_test = {}
data_train_norm = []
data_test_norm = []
gt_train_ = []
gt_test_ = []

# Min-max normalise every volume and record its (x, y, z) shape, keyed by the
# volume index. Iterating over the actual lists (instead of the former
# hard-coded 80 / 20) keeps this correct if the split sizes change.
for k, volume in enumerate(data_train):
    data_train_norm.append(Preprocessing(volume))
    gt_train_.append(data_train_gt[k])
    size.append(data_train_norm[k].shape)
    x[k], y[k], z[k] = size[k]

for k, volume in enumerate(data_test):
    data_test_norm.append(Preprocessing(volume))
    gt_test_.append(data_test_gt[k])
    size_test.append(data_test_norm[k].shape)
    x_test[k], y_test[k], z_test[k] = size_test[k]
    
    
image_resized_train = []
image_resized_test = []
image_resized_train_gt = []
image_resized_test_gt = []

# Ground-truth maps hold class indices, so they must be resampled with
# nearest-neighbour interpolation (order=0) and without skimage's conversion
# to a normalised float image (preserve_range=True). With the defaults
# (order=1, preserve_range=False) the labels are interpolated into fractional
# values, and integer-typed maps are additionally divided by their dtype's
# maximum. The later `.long()` cast then turns (almost) every pixel into
# class 0, i.e. background.
GT_RESCALE_KW = dict(order=0, preserve_range=True, anti_aliasing=False)

# Split every volume into its 2-D slices (last axis) and upscale them by 1.3.
for i in range(len(data_train_norm)):
    for j in range(z[i]):
        image_resized_train.append(rescale(data_train_norm[i][:, :, j], 1.3, anti_aliasing=False))
        image_resized_train_gt.append(rescale(gt_train_[i][:, :, j], 1.3, **GT_RESCALE_KW))

for i in range(len(data_test_norm)):
    for j in range(z_test[i]):
        image_resized_test.append(rescale(data_test_norm[i][:, :, j], 1.3, anti_aliasing=False))
        image_resized_test_gt.append(rescale(gt_test_[i][:, :, j], 1.3, **GT_RESCALE_KW))
  
final_train = []
final_test = []
final_train_gt = []
final_test_gt = []
img_train = []
img_test = []
img_train_gt = []
img_test_gt = []
width = 557
height = 667
data_tensor_test = []
data_tensor_train = []
gt_tensor_test = []
gt_tensor_train = []

# Label maps must keep integer class indices through the 557x667 -> 256x256
# downsampling: nearest neighbour (order=0), no rescaling of the value range
# (preserve_range=True) and no Gaussian pre-smoothing (anti_aliasing=False,
# which skimage otherwise enables when downsampling non-bool images).
# Bilinear + anti-aliasing blends neighbouring classes into fractional
# values that `.long()` later truncates into wrong labels.
GT_RESIZE_KW = dict(order=0, preserve_range=True, anti_aliasing=False)

for i in range(len(image_resized_train)):
    # image: centre on a width x height canvas, then resize to 256x256
    img_train.append(Resize(image_resized_train[i], width, height))
    final_train.append(resize(img_train[i], (256, 256)))
    data_tensor_train.append(torch.from_numpy(np.array(final_train[i])).float())

    # ground truth train: same geometry, but label-preserving resampling
    img_train_gt.append(Resize(image_resized_train_gt[i], width, height))
    final_train_gt.append(resize(img_train_gt[i], (256, 256), **GT_RESIZE_KW))
    gt_tensor_train.append(torch.from_numpy(np.array(final_train_gt[i])).float())

for i in range(len(image_resized_test)):
    img_test.append(Resize(image_resized_test[i], width, height))
    final_test.append(resize(img_test[i], (256, 256)))
    data_tensor_test.append(torch.from_numpy(np.array(final_test[i])).float())

    # ground truth test: label-preserving resampling
    img_test_gt.append(Resize(image_resized_test_gt[i], width, height))
    final_test_gt.append(resize(img_test_gt[i], (256, 256), **GT_RESIZE_KW))
    gt_tensor_test.append(torch.from_numpy(np.array(final_test_gt[i])).float())

# Stack the per-slice 256x256 tensors into (num_slices, 256, 256) tensors.
X_train, X_test = torch.stack(data_tensor_train), torch.stack(data_tensor_test)
Y_train, Y_test = torch.stack(gt_tensor_train), torch.stack(gt_tensor_test)

# first GPU when available, CPU otherwise
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class UNet(nn.Module):
    """2-D U-Net: four conv + max-pool encoder stages, a BatchNorm bottleneck,
    three upsampling decoder stages with skip connections, and an output head.

    The layer attribute names below (conv_encode1, bottleneck, couche_finale, ...)
    become the state_dict keys, so renaming them breaks saved checkpoints.
    """
    def Block_Contraction(self, in_channels, out_channels, kernel_size=3):
        """Two conv + ReLU layers; with kernel_size=3, padding=1 keeps H and W unchanged."""
        
        
        Block = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, padding=1),
                    torch.nn.ReLU())
        return Block
    
    def Block_expansion(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Two conv + ReLU layers to `mid_channel` channels, then 2x bilinear upsampling.

        NOTE(review): `out_channels` is not used; the block outputs `mid_channel` channels.
        """
        
        Block = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=in_channels, out_channels=mid_channel,kernel_size=kernel_size, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(in_channels=mid_channel, out_channels=mid_channel,kernel_size=kernel_size, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Upsample(scale_factor=2,mode='bilinear',align_corners=True)
                    )
        return  Block
    
    def Block_final(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Two conv + ReLU layers, then a 1x1 conv to `out_channels`.

        There is no final activation, so the output is raw per-class scores (logits).
        """
        
        Block = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=in_channels, out_channels=mid_channel, kernel_size=kernel_size, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(in_channels=mid_channel, out_channels=mid_channel,kernel_size=kernel_size, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(in_channels=mid_channel, out_channels=out_channels,kernel_size=1)
                    )
        return  Block
    
    def __init__(self, in_channel, out_channel,dropout=0.5):
        """Build the network.

        Parameters
        ----------
        in_channel : number of input image channels.
        out_channel : number of output classes (channels of the logits).
        dropout : accepted but not used anywhere in the model.
        """
        
        super(UNet, self).__init__()
        #Encode: channels 32 -> 64 -> 128 -> 256, each stage followed by a 2x max-pool
        self.conv_encode1 = self.Block_Contraction(in_channels=in_channel, out_channels=32)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
       
        self.conv_encode2 = self.Block_Contraction(32, 64)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
     
        self.conv_encode3 = self.Block_Contraction(64, 128)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
   
        self.conv_encode4 = self.Block_Contraction(128, 256)
        self.conv_maxpool4 = torch.nn.MaxPool2d(kernel_size=2)
  
        # Bottleneck: 256 -> 512 channels (ReLU then BatchNorm), upsampled 2x at the end
        
        self.bottleneck = torch.nn.Sequential(
                            torch.nn.Conv2d(in_channels=256, out_channels=512,kernel_size=3,padding=1),
                            torch.nn.ReLU(),
                            torch.nn.BatchNorm2d(512),
                            torch.nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3,padding=1),
                            torch.nn.ReLU(),
                            torch.nn.BatchNorm2d(512),
                            torch.nn.Upsample(scale_factor=2,mode='bilinear',align_corners=True)
                            )
        # Decode: input channels = upsampled features + skip connection
        
        # 768 = 512 (bottleneck) + 256 (encode_block4)
        self.conv_decode4 = self.Block_expansion(768, 256, 384)
 
        # 384 = 256 (conv_decode4) + 128 (encode_block3)
        self.conv_decode3 = self.Block_expansion(384, 128, 192)

        # 192 = 128 (conv_decode3) + 64 (encode_block2)
        self.conv_decode2 = self.Block_expansion(192, 64, 96)

        # 96 = 64 (conv_decode2) + 32 (encode_block1)
        self.couche_finale = self.Block_final(96, 32, out_channel)
  
                   
            
    def concatenate(self, first,second):
        """Concatenate two feature maps along dim 1 (the channel axis)."""
        
        return torch.cat((first,second),1)
   
    
    def forward(self, x):
        """Return per-pixel class logits with `out_channel` channels.

        Presumably H and W must be divisible by 16: four 2x max-pools are undone
        by 2x upsampling before each concatenation, and torch.cat needs the
        skip and upsampled maps to have the same spatial size.
        """
        # Encoder
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        
        encode_block2 = self.conv_encode2(encode_pool1)
        
        encode_pool2 = self.conv_maxpool2(encode_block2)
        
        encode_block3 = self.conv_encode3(encode_pool2)
       
        encode_pool3 = self.conv_maxpool3(encode_block3)
        
        
        encode_block4 = self.conv_encode4(encode_pool3)
        encode_pool4 = self.conv_maxpool4(encode_block4)
        
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool4)
        
        # Decoder: concatenate with the matching encoder output, then decode + upsample
        decode_block4 = self.concatenate(bottleneck1, encode_block4)
        cat_layer3 = self.conv_decode4(decode_block4)
        decode_block3 = self.concatenate(cat_layer3, encode_block3)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.concatenate(cat_layer2, encode_block2)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.concatenate(cat_layer1, encode_block1)
        couche_finale = self.couche_finale(decode_block1)

        return  couche_finale
    
def init_weights(k):
    """Initialise a module in place; intended for ``Module.apply``.

    ``nn.Conv2d`` layers get Xavier-uniform weights and zero biases; all
    other modules are left untouched. Convolutions built with ``bias=False``
    have ``k.bias is None``, so the bias reset is skipped for them
    (``nn.init.zeros_(None)`` would raise).
    """
    if isinstance(k, nn.Conv2d):
        nn.init.xavier_uniform_(k.weight)
        if k.bias is not None:
            nn.init.zeros_(k.bias)

Net = UNet(1, 4)  # 1 input channel, 4 output classes
Net.apply(init_weights)
Net.to(device)

# Print a layer-by-layer summary of the whole model for a (1, 256, 256) input.
# torchsummary defaults to device="cuda" and builds its dummy input on the GPU,
# which fails on CPU-only machines; pass the device actually in use.
summary(Net, (1, 256, 256), device=device.type)

class LoadDataset_train(Dataset):
    """Training slices as (image, label) pairs for a DataLoader.

    By default it wraps the module-level ``X_train`` / ``Y_train`` tensors, so
    ``LoadDataset_train()`` behaves as before. Other tensors can be passed
    explicitly. Each image is returned with an extra leading channel axis of
    size 1; labels are returned unchanged.
    """

    def __init__(self, X=None, Y=None):
        # fall back to the module-level tensors when none are given
        self.X = X_train if X is None else X
        self.Y = Y_train if Y is None else Y
        if self.X.shape[0] != self.Y.shape[0]:
            raise ValueError(
                "X and Y must have the same number of samples, got %d and %d"
                % (self.X.shape[0], self.Y.shape[0])
            )
        self.len = self.X.shape[0]

    def __len__(self):
        # dataset size (number of slices)
        return self.len

    def __getitem__(self, idx):
        X, Y = self.X[idx], self.Y[idx]
        # (H, W) -> (1, H, W): add the channel axis that Conv2d expects
        X = X.unsqueeze(0)
        return X, Y
    
class LoadDataset_test(Dataset):
    """Test slices as (image, label) pairs for a DataLoader.

    By default it wraps the module-level ``X_test`` / ``Y_test`` tensors, so
    ``LoadDataset_test()`` behaves as before. Other tensors can be passed
    explicitly. Each image is returned with an extra leading channel axis of
    size 1; labels are returned unchanged.
    """

    def __init__(self, X=None, Y=None):
        # fall back to the module-level tensors when none are given
        self.X = X_test if X is None else X
        self.Y = Y_test if Y is None else Y
        if self.X.shape[0] != self.Y.shape[0]:
            raise ValueError(
                "X and Y must have the same number of samples, got %d and %d"
                % (self.X.shape[0], self.Y.shape[0])
            )
        self.len = self.X.shape[0]

    def __len__(self):
        # dataset size (number of slices)
        return self.len

    def __getitem__(self, idx):
        X, Y = self.X[idx], self.Y[idx]
        # (H, W) -> (1, H, W): add the channel axis that Conv2d expects
        X = X.unsqueeze(0)
        return X, Y
    
# Both loaders: mini-batches of 4 slices, reshuffled each epoch, 2 workers.
loader_kw = dict(batch_size=4, shuffle=True, num_workers=2)

dataset = LoadDataset_train()
trainloader = torch.utils.data.DataLoader(dataset=dataset, **loader_kw)

dataset_test = LoadDataset_test()
trainloader_test = torch.utils.data.DataLoader(dataset=dataset_test, **loader_kw)



# NOTE(review): 10e-3 == 0.01 (10 * 10**-3), not 1e-3 — confirm which was intended.
learning_rate=10e-3
# class weights used in the loss function (one weight per output class)
# NOTE(review): with a single non-zero weight, only pixels labelled 1 contribute
# to the loss. Per the CrossEntropyLoss docs, the weighted mean divides by the
# summed weights of the targets, so a batch without any class-1 pixel gives 0/0 = NaN.
weights = torch.tensor([0.0,1.0,0.0,0.0]).to(device)  ########
# loss function: pixel-wise cross-entropy over the model's raw logits
criterion = nn.CrossEntropyLoss(weights)  
# Adam optimizer with a tiny weight decay
optimizer=optim.Adam(Net.parameters(), lr=learning_rate , weight_decay=1e-8)

score=[]
n_epochs=10
loss_values=[]
loss_values_test=[]
for epoch in range(n_epochs):
    
    Net.train()
    running_loss = 0.0
    
    
    
    for i, data in enumerate(trainloader, 0):
        
            
        # get the inputs; data is a list of [inputs, labels]

        inputs, labels = data[0].to(device), data[1].to(device)
         
        
            
        optimizer.zero_grad()
        outputs=Net(inputs)
        # labels were stored as float tensors; .long() truncates toward zero, so
        # any non-integer (interpolated) label value is silently changed
        loss=criterion(outputs,labels.long())
        running_loss+=loss.item()
        loss.backward()
        optimizer.step()
        # progress output: batch index
        print(i)
        

    # mean training loss of the epoch
    loss_values.append(running_loss/len(trainloader))    
        
    # Validation pass (currently disabled, so loss_values_test stays empty):
    # Net.eval()
    # running_loss_test=0.0
    # with torch.no_grad():
        
        
        
    #     for j, data_test in enumerate(trainloader_test,0):
            
        
    #         inputs_test, labels_test = data_test[0].to(device), data_test[1].to(device)
    #         outputs_test=Net(inputs_test)
    #         loss_test=criterion(outputs_test,labels_test.long())
    #         running_loss_test+=loss_test.item()
            
    
    # loss_values_test.append(running_loss_test/len(trainloader_test))
    
        
        
# plot the evolution of the training loss over the epochs
# (inputs, labels, outputs of the last batch stay defined after the loop)
        
plt.plot(loss_values)
# plt.plot(loss_values_test)
plt.title("Evolution de la loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.show()
         

def replace_pixels(img, lookup=None):
    """Map pixels whose value lies in [1, 20] through a lookup table; all others become 0.

    Parameters
    ----------
    img : 2-D array of integer pixel values.
    lookup : sequence indexed by pixel value, optional
        Output value for each input value. Defaults to the identity table
        over 0..255. NOTE(review): the original table was built as
        ``range(256, 256)``, which is empty, so any pixel in [1, 20] raised
        IndexError; the identity default is an assumption — pass ``lookup``
        explicitly if another mapping was intended.

    Returns
    -------
    Integer array with the same shape as *img*.

    ``dtype=np.int`` is replaced by the builtin ``int`` (the ``np.int`` alias
    was removed in NumPy 1.24), and the per-pixel double loop is vectorised.
    """
    if lookup is None:
        lookup = np.arange(256)
    lookup = np.asarray(lookup, dtype=int)
    img = np.asarray(img)
    temp_mask = np.zeros(img.shape, dtype=int)
    in_range = (img >= 1) & (img <= 20)
    temp_mask[in_range] = lookup[img[in_range].astype(int)]
    return temp_mask

      
# Visual check and Dice score on the LAST training batch: `inputs`, `labels`
# and `outputs` are the variables left over from the training loop above.
for i in range(len(outputs)):
    image_imshow(inputs[i][0].cpu().numpy(), labels[i].cpu().numpy())

    # Binarise prediction and label (any class index >= 1 counts as foreground).
    # conversion() may binarise its argument in place, so hand it a copy of
    # the label to leave the `labels` batch intact.
    pred_mask = conversion(torch.argmax(outputs[i], 0))
    true_mask = conversion(labels[i].clone())
    score.append(dice_coeff(pred_mask, true_mask))

print('Finished Training')
im_imshow(torch.argmax(outputs[0], 0).cpu().numpy())
# classes that actually appear in the predictions of the last batch
preds = outputs.argmax(1)
print(torch.unique(preds))

Thank you!

The model itself seems to be alright, as it is able to predict all 5 classes when trained on this random data:

# Sanity check: fit the UNet on random inputs/targets with 5 classes and
# verify that every class can appear in the argmax prediction.
model = UNet(1, 5).cuda()
data = torch.randn(10, 1, 224, 224).cuda()
target = torch.randint(0, 5, (10, 224, 224)).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    epoch_loss = loss.item()
    print('epoch {}, loss {}'.format(epoch, epoch_loss))

predicted_classes = output.argmax(1)
print(torch.unique(predicted_classes))

> epoch 0, loss 1.6101008653640747
epoch 1, loss 1.609655737876892
epoch 2, loss 1.6094974279403687
epoch 3, loss 1.6103358268737793
epoch 4, loss 1.609430193901062
epoch 5, loss 1.6094392538070679
epoch 6, loss 1.6094383001327515
epoch 7, loss 1.6094332933425903
epoch 8, loss 1.6094229221343994
epoch 9, loss 1.6094087362289429
tensor([0, 1, 2, 3, 4], device='cuda:0')

You should check whether the dice definition is correct for multi-class segmentation.
If I’m not mistaken, it’s not used for the model training, so it shouldn’t influence the result.

PS: you can post code snippets by wrapping them into three backticks ```. I’ve edited your code for easier copy-pasting. :wink:

So sorry, I’m new to this forum; I will be more careful next time.

Thank you a lot, I will take a look.
I think the problem is in my data and not in my model; I’m confused!
Thank you for giving me some of your time, have a good day :blush:

So sorry, I’m new to this forum; I will be more careful next time.

Thank you a lot, I will take a look.
I think the problem is in my data and not in my model; I’m confused!
Thank you for giving me some of your time, have a good day :blush:

Sir, I have 5 classes in the ground truth, but my model predicts only 4 classes. Can you suggest any ideas?

Are you dealing with an imbalanced dataset?
If so, you could e.g. oversample the minority class using a WeightedRandomSampler or use class weighting in the criterion.
If not, could you post the shape of your model output, the target, and the criterion you are using, and explain your use case a bit?

1 Like

outputs shape = torch.Size([8, 5, 128, 128])
target shape = torch.Size([8, 1, 128, 128])
I am using the cross-entropy criterion.
My question: when I use torch.unique(outputs.argmax(1)), it gives [0, 1, 2, 3, 4], i.e. all five classes, but when I use the saved model for prediction, it only gives [0, 1, 2, 3].
Can you please give me a suggestion?

The target shape would raise a shape mismatch in nn.CrossEntropyLoss, so I guess dim1 is removed (e.g. via squeeze(1)) somewhere in your code.

Was your model predicting all 5 classes during validation, and does this issue only appear once you store and reload the model?

1 Like

I am posting my code here:
import os

# Restrict this process to GPU 1; this must run before torch initialises CUDA.
# (The line was previously indented by one space, which is an IndentationError.)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
# sys.path.append('../src')
import numpy as np
import torch
import time
from PIL import Image
# import cv2
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
from torch.autograd import Variable
import torch.utils.data as data
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from sklearn.model_selection import train_test_split
from new_data import FishDataset
from last_model import UNet
from data_process.utils import label_mapping, RemoteSensingDataset
from sklearn.metrics import confusion_matrix, cohen_kappa_score


def img_transforms(img, label):
    """Convert an (image, label) pair to tensors.

    The image goes through ToTensor and per-channel mean/std normalisation
    (the usual ImageNet statistics); the label array is wrapped with
    torch.from_numpy unchanged.
    """
    # img, label = random_crop(img, label, crop_size)
    transform = transforms.Compose([
        # transforms.Resize(size=(128, 128)),
        # transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img = transform(img)
    label = torch.from_numpy(label)
    return img, label


train_dataset = RemoteSensingDataset(True, img_transforms)
val_dataset = RemoteSensingDataset(False, img_transforms)
train_data = data.DataLoader(train_dataset, 8, shuffle=True, num_workers=0)
val_data = data.DataLoader(val_dataset, 8, num_workers=0)

model = UNet()
model.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model_folder = os.path.abspath('E:/new_mission _data')
os.makedirs(model_folder, exist_ok=True)
model_path = os.path.join(model_folder, 'last_model5.pt')


hist = {'loss': [], 'jaccard': [], 'val_loss': [], 'val_jaccard': []}
num_epochs = 3
display_steps = 50
best_jaccard = 0
tm0 = time.time()

for epoch in range(num_epochs):
    print('Starting epoch {}/{}'.format(epoch+1, num_epochs))
    # train
    model.train()
    running_loss = 0.0
    running_jaccard = 0.0
    pass_count = 0
    for batch_idx, (img, label) in enumerate(train_data):
        pass_count += 1
        images = img.cuda()
        # CrossEntropyLoss expects (N, H, W) long class indices: drop a
        # singleton channel axis (N, 1, H, W) if present and cast, exactly as
        # the validation loop below does.
        masks = label.cuda().squeeze(1).long()
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % display_steps == 0:
            tm1 = time.time()
            delta = int(tm1 - tm0)
            print('    %d'%(delta), end='')
            print('batch {:>3}/{:>3} loss: {:.4f}'
                  .format(batch_idx+1, len(train_data),
                          loss.item()))

    # evaluate
    print('Finished epoch {}, starting evaluation'.format(epoch+1))
    model.eval()
    val_running_loss = 0.0
    val_running_jaccard = 0.0
    # evaluation needs no autograd graph; no_grad saves GPU memory
    with torch.no_grad():
        for img, label in val_data:
            images = img.cuda()
            masks = label.cuda().squeeze(1)
            outputs = model(images).float()
            loss = criterion(outputs, masks.long())
            val_running_loss += loss.item()

    train_loss = running_loss / len(train_data)
    train_jaccard = running_jaccard / len(train_data)
    val_loss = val_running_loss / len(val_data)

    hist['loss'].append(train_loss)
    hist['val_loss'].append(val_loss)

    # NOTE(review): the model is saved after every epoch, not only on improvement
    torch.save(model, model_path)
    torch.save(model.state_dict(), os.path.join(model_folder, 'unet_state_dict5.pt'))
    print('    ', end='')
    print('loss: {:.4f}'\
           .format(train_loss))
    print('    val_loss: {:.4f}'.format(val_loss))

@ptrblck Sir, please correct me; I have tried a lot.

Thanks for the code. Could you answer these open question from my previous posts first, please?

  • Are you dealing with an imbalanced dataset?
  • Is your model predicting all 5 classes during validation and just fails after storing and loading the trained state_dict?

Yes sir, I am dealing with imbalanced data, and I assigned different weights, but I still have the problem you described: the model predicts all 5 classes during validation but fails after storing and loading the trained state_dict.

If your model is not predicting the minority class, you might need to increase the sample weights for these class sample and force the model to predict them.

Dear sir, I assigned different weights according to the count of each class, but my losses look like this; please have a look at this varying loss:
Starting epoch 1/20
3batch 1/975 loss: 1.6033
185batch 151/975 loss: 1.4414
369batch 301/975 loss: 1.5746
552batch 451/975 loss: 1.3848
735batch 601/975 loss: 1.3494
922batch 751/975 loss: 1.4717
1218batch 901/975 loss: 1.3614
Finished epoch 1, starting evaluation
loss: 1.4235
Starting epoch 2/20
1417batch 1/975 loss: 1.4824
1600batch 151/975 loss: 1.3209
1783batch 301/975 loss: 1.2848
1965batch 451/975 loss: 1.0356
2150batch 601/975 loss: 1.2677
2340batch 751/975 loss: 1.1203
2522batch 901/975 loss: 1.3004
Finished epoch 2, starting evaluation
loss: 1.3725
Starting epoch 3/20
2712batch 1/975 loss: 1.1519
2892batch 151/975 loss: 1.4663
3074batch 301/975 loss: 1.6153
3254batch 451/975 loss: 1.3504
3434batch 601/975 loss: 1.1647
3614batch 751/975 loss: 1.2644
3794batch 901/975 loss: 1.3707
Finished epoch 3, starting evaluation
loss: 1.3314

The loss might increase while the accuracy decreases due to the accuracy paradox. For an imbalanced dataset, you could instead look at the confusion matrix and make sure the per-class accuracies are as expected.