The train_loss and the valid_loss will not converge to a local minimum

Hi everyone,
I hope you can help me. I’m doing a project for the university and I’m having some problems. I’m using HydraPlusNet to solve a problem of PAR and I need to fine tune the net. In particular, from the original 26 classes I need to classify only 2, by training the net with a custom dataset.
During training, the train_loss and the validation_loss do not converge to a local minimum. In particular, the training phase reaches a minimum, i.e. 0.1 on train_loss and 0.31 on val_loss, but then both losses start to increase.

This is the code I use for training, which comes from the GitHub repository of HydraPlusNet:
import os
import torch
import torch.utils.data as data
from PIL import Image
import matplotlib.pyplot as plt
import torch.utils
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import scipy.io as scio
import torchvision.transforms as transforms

from lib.AF import AF
from lib.MNet import MNet
from lib.Hydraplus import HP
from lib import dataload
from tqdm import tqdm

from torch.autograd import Variable
import argparse
import logging
import pdb
import numpy as np
from visdom import Visdom

# Visdom windows for live loss curves; the initial Y value is a placeholder
# point, real values are appended during training.
viz = Visdom()
win = viz.line(
    Y=np.array([0.2]),
    name="1",
    opts=dict(title="Training Loss"),
)
win2 = viz.line(
    Y=np.array([0.2]),
    name="1",
    opts=dict(title="Validation Loss"),
)

def freeze_layers(net):
    """Freeze the first half of the model's parameter tensors.

    NOTE(review): this counts parameter *tensors* (weights and biases
    separately), not layers, so the split point only roughly corresponds
    to "the first half of the network" — confirm this is the intent.
    """
    params = list(net.parameters())
    freeze_until = len(params) // 2  # freeze the first half of the tensors
    for i, param in enumerate(params):
        # Earlier tensors are frozen; later ones stay trainable.
        param.requires_grad = i >= freeze_until

def parse_args():
    """Parse command-line options: model choice, checkpoint paths, hyper-parameters."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', help="choose model",
                        choices=['MNet', 'AF1', 'AF2', 'AF3', 'HP'])

    # pre-trained checkpoint
    parser.add_argument('-r', dest='r', help="resume training", default=False)
    parser.add_argument('-checkpoint', dest='checkpoint', help="load weight path", default=None)
    parser.add_argument('-mpath', dest='mpath', help="load MNet weight path", default=None)
    parser.add_argument('-af1path', dest='af1path', help="load AF1 weight path", default=None)
    parser.add_argument('-af2path', dest='af2path', help="load AF2 weight path", default=None)
    parser.add_argument('-af3path', dest='af3path', help="load AF3 weight path", default=None)

    # training hyper-parameters
    parser.add_argument('-nw', dest='nw', help="number of workers for dataloader",
                        default=0, type=int)
    parser.add_argument('-bs', dest='bs', help="batch size",
                        default=100, type=int)
    parser.add_argument('-lr', dest='lr', help="learning rate",
                        default=0.001, type=float)
    parser.add_argument('-mGPUs', dest='mGPUs',
                        help='whether use multiple GPUs',
                        action='store_true')

    return parser.parse_args()

Function to save a checkpoint:

def checkpoint_save(args_m, state_dict, epoch):
    """Serialize *state_dict* to ./checkpoint/<model>epoch<N> and return the path.

    NOTE: the resume logic recovers the epoch number from the digits in this
    filename, so the "epoch{}" format must stay in sync with it.
    """
    save_path = "./checkpoint/" + args_m + "epoch{}".format(epoch)
    torch.save(state_dict, save_path)
    return save_path

Function to initialize weights:

def weight_init(m):
    """Xavier-normal initialization for Conv2d weights; no-op for other modules."""
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight.data)

Function to run validation (its result drives the early-stopping logic):

def validate(net, val_loader, criterion, checkpoint_path):
    """Compute the mean per-batch validation loss over *val_loader*.

    NOTE(review): reloading *checkpoint_path* here is redundant when it was
    just saved from this same ``net``, and the keys will not match if one
    side is wrapped in DataParallel and the other is not — confirm.
    """
    net.eval()
    net.load_state_dict(torch.load(checkpoint_path))
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels, _ in val_loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs, _confidence = net(inputs)  # use the primary logits only
            labels_one_hot = F.one_hot(labels, num_classes=2).float()
            val_loss += criterion(outputs, labels_one_hot).item()
    # BUG FIX: restore training mode.  Without this the net stays in eval()
    # for all subsequent epochs (frozen BatchNorm stats, no dropout), which
    # can make the training/validation losses drift upward.
    net.train()
    return val_loss / len(val_loader)

def main():
    """Train the selected HydraPlusNet branch with validation-driven early stopping."""
    args = parse_args()

    mytransform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
    ])

    # Load datasets
    data_set = dataload.myImageFloder(
        root="/home/alfonso/Desktop/Dumping_Garbage_Dataset.v8i.retinanet/train/train/",
        annotation_file="/home/alfonso/Desktop/Dumping_Garbage_Dataset.v8i.retinanet/train/_annotations.csv",
        transform=mytransform,
        mode='train'
    )
    validation_set = dataload.myImageFloder(
        root="/home/alfonso/Desktop/Dumping_Garbage_Dataset.v8i.retinanet/valid/valid/",
        annotation_file="/home/alfonso/Desktop/Dumping_Garbage_Dataset.v8i.retinanet/valid/_annotations.csv",
        transform=mytransform,
        mode='valid'
    )

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(data_set, batch_size=args.bs, shuffle=True, num_workers=args.nw)
    val_loader = torch.utils.data.DataLoader(validation_set, batch_size=args.bs, shuffle=False, num_workers=args.nw)

    print('Image numbers {}'.format(len(data_set)))

    # Build the requested branch; frozen parts keep their pre-trained weights.
    if args.m == 'MNet':
        net = MNet()
        if not args.r:
            net.apply(weight_init)
        freeze_layers(net)
    elif 'AF' in args.m:
        net = AF(af_name=args.m)
        if not args.r:
            net.MNet.load_state_dict(torch.load(args.mpath))
        for param in net.MNet.parameters():
            param.requires_grad = False
    elif args.m == 'HP':
        net = HP()
        if not args.r:
            net.MNet.load_state_dict(torch.load(args.mpath))
            net.AF1.load_state_dict(torch.load(args.af1path))
            net.AF2.load_state_dict(torch.load(args.af2path))
            net.AF3.load_state_dict(torch.load(args.af3path))
        for branch in (net.MNet, net.AF1, net.AF2, net.AF3):
            for param in branch.parameters():
                param.requires_grad = False

    # Resume training: the epoch number is recovered from the digits embedded
    # in the checkpoint filename (see checkpoint_save).
    start_epoch = 1
    if args.r:
        net.load_state_dict(torch.load(args.checkpoint))
        numeric_string = "".join(filter(str.isdigit, args.checkpoint))
        start_epoch = int(numeric_string) + 1

    net.cuda()
    if args.mGPUs:
        net = nn.DataParallel(net)

    # Loss and optimizer; only un-frozen parameters are optimized.
    loss_cls_weight = [1, 1]  # Carrying, NotCarrying
    criterion = nn.BCEWithLogitsLoss(weight=torch.Tensor(loss_cls_weight))
    criterion.cuda()
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, net.parameters()),
        lr=args.lr, momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

    # Initialize logging
    logging.basicConfig(level=logging.DEBUG, filename='./result/training_log/' + args.m + '.log',
                        datefmt='%Y/%m/%d %H:%M:%S', format='%(asctime)s - %(message)s')
    logger = logging.getLogger(__name__)

    # Early stopping state
    best_val_loss = float('inf')
    patience = 25
    epochs_without_improvement = 0
    val_loss = 0.0

    # Training loop
    for epoch in range(start_epoch, 100):
        # BUG FIX: validate() puts the net in eval() mode; re-enable training
        # mode at the start of every epoch so BatchNorm updates its running
        # stats and dropout is active during training.
        net.train()
        running_loss = 0.0
        i = 0
        with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', ncols=100) as pbar:
            for i, (inputs, labels, _) in enumerate(train_loader):
                # Variable() is obsolete; tensors carry autograd state directly.
                inputs, labels = inputs.cuda(), labels.cuda()
                optimizer.zero_grad()
                outputs, confidence = net(inputs)  # primary logits + softmax confidence
                labels_one_hot = F.one_hot(labels, num_classes=2).float()
                loss = criterion(outputs, labels_one_hot)
                loss.backward()
                optimizer.step()
                running_loss = loss.item()  # NOTE(review): last-batch loss, not an epoch average
                pbar.set_postfix(loss=loss.item(), val_loss=val_loss)
                pbar.update()

        logger.info('[%d %5d] loss: %.6f' % (epoch, i + 1, running_loss))
        viz.line(
            X=np.array([epoch + i / 5000.0]),
            Y=np.array([running_loss]),
            win=win,
            update="append"
        )

        # Validation (runs on the checkpoint just written for this epoch)
        checkpoint_path = checkpoint_save(args.m, net.state_dict(), epoch)
        val_loss = validate(net, val_loader, criterion, checkpoint_path)
        scheduler.step(val_loss)
        viz.line(
            X=np.array([epoch + i / 5000.0]),
            Y=np.array([val_loss]),
            win=win2,
            update="append"
        )
        logger.info('Validation Loss: %.6f' % val_loss)

        # Early stopping bookkeeping (the epoch checkpoint is already saved above)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            print("\nCheckpoint Saved at epoch " + str(epoch) + "\n")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= 10:
                # NOTE(review): multiplying the LR by 1e-5 on *every* stale
                # epoch — on top of ReduceLROnPlateau — collapses the learning
                # rate to ~0 almost immediately; confirm this is intended.
                for param_group in optimizer.param_groups:
                    param_group['lr'] = param_group['lr'] * 10 ** -5
                    print("\nLearning Rate decreased\n")
        if epochs_without_improvement >= patience:
            logger.info('Early stopping at epoch %d' % epoch)
            break


if __name__ == '__main__':
    main()

And this is the network definition:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MNet(nn.Module):
    """M-Net trunk: Inception-style feature extractor with a small classifier head.

    When ``feat_out`` is True, forward() returns the four intermediate feature
    maps (consumed by the AF attention branches) instead of the prediction.
    """

    def __init__(self, num_classes=2, feat_out=False):
        super(MNet, self).__init__()
        self.Conv2d_1_7x7_s2 = BasicConv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.Conv2d_2_1x1 = BasicConv2d(32, 32, kernel_size=1)
        self.Conv2d_3_3x3 = BasicConv2d(32, 96, kernel_size=3, padding=1)

        self.incept_block_1 = InceptBlock1()
        self.incept_block_2 = InceptBlock2()
        self.incept_block_3 = InceptBlock3()

        self.final_fc = nn.Linear(512, num_classes)  # head resized for the 2-class task

        self.feat_out = feat_out  # output intermediate features for AF branches

        # Weight initialization (Xavier for every Conv2d, via weight_init below)
        self.apply(weight_init)

    def forward(self, x):
        # 3 x 299 x 299
        x = self.Conv2d_1_7x7_s2(x)
        # 32 x 155 x 155
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # 32 x 74 x 74
        x = self.Conv2d_2_1x1(x)
        # 32 x 74 x 74
        x = self.Conv2d_3_3x3(x)
        # 96 x 74 x 74
        x0 = F.max_pool2d(x, kernel_size=3, stride=2)
        # x0 = 96 x 36 x 36

        x1 = self.incept_block_1(x0)  # x1 = 256 x 18 x 18
        x2 = self.incept_block_2(x1)  # x2 = 502 x 9 x 9
        x3 = self.incept_block_3(x2)  # x3 = 512 x 9 x 9

        x = F.avg_pool2d(x3, kernel_size=9, stride=1)  # 512 x 1 x 1
        x = F.dropout(x, training=self.training)
        x = x.view(x.size(0), -1)  # flatten to 512
        pred_class = self.final_fc(x)

        confidence = F.softmax(pred_class, dim=1)
        if self.feat_out:
            return x0, x1, x2, x3
        else:
            return pred_class, confidence

class BasicConv2d(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU building block."""

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        # bias is redundant because BatchNorm re-centers the activations
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)

class InceptBlock1(nn.Module):
    """First Inception stage: 96 -> 128 channels (A), then 128 -> 256 with downsampling (B)."""

    def __init__(self):
        super(InceptBlock1, self).__init__()
        self.module_a = InceptionA(in_channels=96, b_1x1_out=32, b_5x5_1_out=32, b_5x5_2_out=32,
                                   b_3x3_1_out=32, b_3x3_2_out=48, b_3x3_3_out=48, b_pool_out=16)  # C_out = 128
        self.module_b = InceptionB(in_channels=128, b_1x1_1_out=64, b_1x1_2_out=80,
                                   b_3x3_1_out=32, b_3x3_2_out=48, b_3x3_3_out=48)

    def forward(self, x):
        x = self.module_a(x)
        x = self.module_b(x)
        return x

class InceptBlock2(nn.Module):
    """Second Inception stage: 256 -> 288 channels (A), then 288 -> 502 with downsampling (B)."""

    def __init__(self):
        super(InceptBlock2, self).__init__()
        self.module_a = InceptionA(in_channels=256, b_1x1_out=112, b_5x5_1_out=32, b_5x5_2_out=48,
                                   b_3x3_1_out=48, b_3x3_2_out=64, b_3x3_3_out=64, b_pool_out=64)  # C_out = 288
        self.module_b = InceptionB(in_channels=288, b_1x1_1_out=64, b_1x1_2_out=86,
                                   b_3x3_1_out=96, b_3x3_2_out=128, b_3x3_3_out=128)  # C_out = 502 (86+128+288)

    def forward(self, x):
        x = self.module_a(x)
        x = self.module_b(x)
        return x

class InceptBlock3(nn.Module):
    """Third Inception stage: two InceptionA modules, 502 -> 512 -> 512 channels (no downsampling)."""

    def __init__(self):
        super(InceptBlock3, self).__init__()
        self.module_a = InceptionA(in_channels=502, b_1x1_out=176, b_5x5_1_out=96, b_5x5_2_out=160,
                                   b_3x3_1_out=80, b_3x3_2_out=112, b_3x3_3_out=112, b_pool_out=64)
        self.module_b = InceptionA(in_channels=512, b_1x1_out=176, b_5x5_1_out=96, b_5x5_2_out=160,
                                   b_3x3_1_out=96, b_3x3_2_out=112, b_3x3_3_out=112, b_pool_out=64)

    def forward(self, x):
        x = self.module_a(x)
        x = self.module_b(x)
        return x

class InceptionA(nn.Module):
    """Four parallel branches (1x1, "5x5" via 3x3, double 3x3, pool+1x1) concatenated on channels.

    All branches preserve the spatial size; the output channel count is the
    sum of the four branch outputs.
    """

    def __init__(
        self,
        in_channels,
        b_1x1_out,
        b_5x5_1_out,
        b_5x5_2_out,
        b_3x3_1_out,
        b_3x3_2_out,
        b_3x3_3_out,
        b_pool_out
    ):
        super(InceptionA, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, b_1x1_out, kernel_size=1)  # H_out = H_in

        # Despite the name, the "5x5" branch uses a 3x3 convolution.
        self.branch5x5_1 = BasicConv2d(in_channels, b_5x5_1_out, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(b_5x5_1_out, b_5x5_2_out, kernel_size=3, padding=1)  # H_out = H_in

        self.branch3x3dbl_1 = BasicConv2d(in_channels, b_3x3_1_out, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(b_3x3_1_out, b_3x3_2_out, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(b_3x3_2_out, b_3x3_3_out, kernel_size=3, padding=1)

        self.branch_pool = BasicConv2d(in_channels, b_pool_out, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # NOTE(review): avg_pool2d with kernel_size=1, stride=1 is a no-op;
        # kept for fidelity with the original, but likely meant to be 3x3.
        branch_pool = F.avg_pool2d(x, kernel_size=1, stride=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)

class InceptionB(nn.Module):
    """Downsampling Inception module: three stride-2 branches concatenated on channels.

    Output channels = b_1x1_2_out + b_3x3_3_out + in_channels (pooled input
    passes through unchanged).
    """

    def __init__(
        self,
        in_channels,
        b_1x1_1_out,
        b_1x1_2_out,
        b_3x3_1_out,
        b_3x3_2_out,
        b_3x3_3_out
    ):
        super(InceptionB, self).__init__()
        # NOTE(review): attribute names kept for checkpoint compatibility,
        # but branch1x1_2 is actually a 3x3 stride-2 convolution.
        self.branch1x1_1 = BasicConv2d(in_channels, b_1x1_1_out, kernel_size=1)
        self.branch1x1_2 = BasicConv2d(b_1x1_1_out, b_1x1_2_out, kernel_size=3, stride=2, padding=1)

        self.branch3x3dbl_1 = BasicConv2d(in_channels, b_3x3_1_out, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(b_3x3_1_out, b_3x3_2_out, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(b_3x3_2_out, b_3x3_3_out, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        branch3x3 = self.branch1x1_1(x)
        branch3x3 = self.branch1x1_2(branch3x3)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)

def weight_init(m):
    """Xavier-normal initialization for Conv2d weights; no-op for other modules."""
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight.data)