Hi everyone,
I hope you can help me. I'm working on a university project and running into some problems. I'm using HydraPlusNet for a pedestrian attribute recognition (PAR) task and I need to fine-tune the network. In particular, instead of the original 26 classes I only need to classify 2, training the network on a custom dataset.
During training, the train loss and the validation loss do not converge to a local minimum: they reach a minimum (roughly 0.1 for train_loss and 0.31 for val_loss), but then both losses start to increase.
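To make the setup concrete, here is a rough sketch of the 2-class adaptation on a dummy batch (just an illustration, not the actual training script; MNet is the modified network posted further below, and BCEWithLogitsLoss with one-hot targets is what I use in the real code):
import torch
import torch.nn as nn
import torch.nn.functional as F
from lib.MNet import MNet

net = MNet(num_classes=2)            # final_fc is now Linear(512, 2) instead of 26 outputs
criterion = nn.BCEWithLogitsLoss()

# one forward/backward pass on a dummy batch of four 299x299 images with binary labels
images = torch.randn(4, 3, 299, 299)
labels = torch.tensor([0, 1, 1, 0])
logits, confidence = net(images)     # the network returns logits and softmax scores
targets = F.one_hot(labels, num_classes=2).float()
loss = criterion(logits, targets)
loss.backward()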
This is the full training code, adapted from the HydraPlusNet GitHub repository:
import os
import torch
import torch.utils.data as data
from PIL import Image
import matplotlib.pyplot as plt
import torch.utils
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import scipy.io as scio
import torchvision.transforms as transforms
from lib.AF import AF
from lib.MNet import MNet
from lib.Hydraplus import HP
from lib import dataload
from tqdm import tqdm
from torch.autograd import Variable
import argparse
import logging
import pdb
import numpy as np
from visdom import Visdom
viz = Visdom()
win = viz.line(
Y=np.array([0.2]),
name="1",
opts=dict(title="Training Loss")
)
win2 = viz.line(
Y=np.array([0.2]),
name="1",
opts=dict(title="Validation Loss")
)
def freeze_layers(net):
"""Freeze the first layers of the model."""
# List all the parameters
params = list(net.parameters())
total_layers = len(params)
freeze_until = total_layers // 2 # Freeze first half
for i, param in enumerate(params):
if i < freeze_until:
param.requires_grad = False
else:
param.requires_grad = True
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-m', help="choose model", choices=['MNet', 'AF1', 'AF2', 'AF3', 'HP'])
# pre-trained checkpoint
parser.add_argument('-r', dest='r', help="resume training", default=False)
parser.add_argument('-checkpoint', dest='checkpoint', help="load weight path", default=None)
parser.add_argument('-mpath', dest='mpath', help="load MNet weight path", default=None)
parser.add_argument('-af1path', dest='af1path', help="load AF1 weight path", default=None)
parser.add_argument('-af2path', dest='af2path', help="load AF2 weight path", default=None)
parser.add_argument('-af3path', dest='af3path', help="load AF3 weight path", default=None)
# training hyper-parameters
parser.add_argument('-nw', dest='nw', help="number of workers for dataloader",
default=0, type=int)
parser.add_argument('-bs', dest='bs', help="batch size",
default=100, type=int)
parser.add_argument('-lr', dest='lr', help="learning rate",
default=0.001, type=float)
parser.add_argument('-mGPUs', dest='mGPUs',
help='whether use multiple GPUs',
action='store_true')
args = parser.parse_args()
return args
# Function to save a checkpoint
def checkpoint_save(args_m, state_dict, epoch):
save_path = "./checkpoint/" + args_m + "epoch{}".format(epoch)
torch.save(state_dict, save_path)
return save_path
# Function to initialize weights
def weight_init(m):
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight.data)
# Validation function (used for early stopping)
def validate(net, val_loader, criterion, checkpoint_path):
net.eval()
val_loss = 0.0
total_batches = len(val_loader)
net.load_state_dict(torch.load(checkpoint_path))
with torch.no_grad():
for i, (inputs, labels, _) in enumerate(val_loader):
inputs, labels = inputs.cuda(), labels.cuda()
outputs, confidence = net(inputs) # Get the primary output
labels_one_hot = F.one_hot(labels, num_classes=2).float()
val_loss += criterion(outputs, labels_one_hot).item()
#if i % 1000 == 0:
return val_loss / len(val_loader)
def main():
args = parse_args()
mytransform = transforms.Compose(
[transforms.RandomHorizontalFlip(),
transforms.Resize((299, 299)),
transforms.ToTensor()]
)
# Load dataset
data_set = dataload.myImageFloder(
root="/home/alfonso/Desktop/Dumping_Garbage_Dataset.v8i.retinanet/train/train/",
annotation_file="/home/alfonso/Desktop/Dumping_Garbage_Dataset.v8i.retinanet/train/_annotations.csv",
transform=mytransform,
mode='train'
)
validation_set = dataload.myImageFloder(
root="/home/alfonso/Desktop/Dumping_Garbage_Dataset.v8i.retinanet/valid/valid/",
annotation_file="/home/alfonso/Desktop/Dumping_Garbage_Dataset.v8i.retinanet/valid/_annotations.csv",
transform=mytransform,
mode='valid'
)
# Create data loaders
train_loader = torch.utils.data.DataLoader(data_set, batch_size=args.bs, shuffle=True, num_workers=args.nw)
val_loader = torch.utils.data.DataLoader(validation_set, batch_size=args.bs, shuffle=False, num_workers=args.nw)
print('Image numbers {}'.format(len(data_set)))
# Define the training model
if args.m == 'MNet':
net = MNet()
if not args.r:
net.apply(weight_init)
freeze_layers(net)
elif 'AF' in args.m:
net = AF(af_name=args.m)
if not args.r:
net.MNet.load_state_dict(torch.load(args.mpath))
for param in net.MNet.parameters():
param.requires_grad = False
elif args.m == 'HP':
net = HP()
if not args.r:
net.MNet.load_state_dict(torch.load(args.mpath))
net.AF1.load_state_dict(torch.load(args.af1path))
net.AF2.load_state_dict(torch.load(args.af2path))
net.AF3.load_state_dict(torch.load(args.af3path))
for param in net.MNet.parameters():
param.requires_grad = False
for param in net.AF1.parameters():
param.requires_grad = False
for param in net.AF2.parameters():
param.requires_grad = False
for param in net.AF3.parameters():
param.requires_grad = False
# Resume training and load the checkpoint from last training
start_epoch = 1
if args.r:
net.load_state_dict(torch.load(args.checkpoint))
numeric_filter = filter(str.isdigit, args.checkpoint)
numeric_string = "".join(numeric_filter)
start_epoch = int(numeric_string) + 1
net.cuda()
if args.mGPUs:
net = nn.DataParallel(net)
net.train()
# Loss function and optimizer
loss_cls_weight = [1, 1] # Carrying, NotCarrying
weight = torch.Tensor(loss_cls_weight)
criterion = nn.BCEWithLogitsLoss(weight=weight)
criterion.cuda()
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)
# Initialize logging
logging.basicConfig(level=logging.DEBUG, filename='./result/training_log/' + args.m + '.log',
datefmt='%Y/%m/%d %H:%M:%S', format='%(asctime)s - %(message)s')
logger = logging.getLogger(__name__)
# Early stopping parameters
best_val_loss = float('inf')
patience = 25
epochs_without_improvement = 0
running_loss = 0.0
val_loss=0.0
# Training loop
for epoch in range(start_epoch, 100):
total_batches = len(train_loader)
# Wrap train_loader with tqdm for progress bar
with tqdm(total=total_batches, desc=f'Epoch {epoch}', ncols=100) as pbar:
for i, data in enumerate(train_loader, 0):
inputs, labels, _ = data
inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
optimizer.zero_grad()
outputs, confidence = net(inputs) # get logits and confidence from the network
labels_one_hot = F.one_hot(labels, num_classes=2).float()
loss = criterion(outputs, labels_one_hot)
loss.backward()
optimizer.step()
running_loss = loss.data.item()
pbar.set_postfix(loss=loss.item(), val_loss=val_loss) # Update progress bar with current losses
pbar.update() # Update progress bar
#if i % 1000 == 0:
logger.info('[%d %5d] loss: %.6f' % (epoch, i + 1, loss))
viz.line(
X=np.array([epoch + i / 5000.0]),
Y=np.array([running_loss]),
win=win,
update="append"
)
# Validation loop
checkpoint_path=checkpoint_save(args.m, net.state_dict(), epoch)
val_loss = validate(net, val_loader, criterion,checkpoint_path)
scheduler.step(val_loss)
viz.line(
X=np.array([epoch + i / 5000.0]),
Y=np.array([val_loss]),
win=win2,
update="append")
logger.info('Validation Loss: %.6f' % val_loss)
pbar.set_postfix(loss=loss.item(), val_loss=val_loss) # Update progress bar with current losses
pbar.update() # Update progress bar
if i % 1000 == 0:
logger.info('[%d %5d] loss: %.6f' % (epoch, i + 1, loss))
# Check for early stopping
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_without_improvement = 0
print("\nCheckpoint Saved at epoch "+ str(epoch) + "\n")
checkpoint_path=checkpoint_save(args.m, net.state_dict(), epoch)
else:
epochs_without_improvement += 1
if epochs_without_improvement >= 10 :
for param_group in optimizer.param_groups:
param_group['lr'] = param_group['lr'] * 10**-5
print("\nLearning Rate decreased\n")
if epochs_without_improvement >= patience:
logger.info('Early stopping at epoch %d' % epoch)
break
# Save checkpoint
#if epoch % 1 == 0:
#if args.mGPUs:
# checkpoint_save(args.m, net.module.state_dict(), epoch)
#else:
# checkpoint_save(args.m, net.state_dict(), epoch)
if __name__ == '__main__':
main()
And this is the network definition:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MNet(nn.Module):
def __init__(self, num_classes=2, feat_out=False):
super(MNet, self).__init__()
self.Conv2d_1_7x7_s2 = BasicConv2d(3, 32, kernel_size=7, stride=2, padding=3)
self.Conv2d_2_1x1 = BasicConv2d(32, 32, kernel_size=1)
self.Conv2d_3_3x3 = BasicConv2d(32, 96, kernel_size=3, padding=1)
self.incept_block_1 = InceptBlock1()
self.incept_block_2 = InceptBlock2()
self.incept_block_3 = InceptBlock3()
self.final_fc = nn.Linear(512, num_classes) # Updated number of classes here
self.feat_out = feat_out # output intermediate features for AF branches
# Weight initialization
self.apply(weight_init)
def forward(self, x):
# 3 x 299 x 299
x = self.Conv2d_1_7x7_s2(x)
# 32 x 155 x 155
x = F.max_pool2d(x, kernel_size=3, stride=2)
# 32 x 74 x74
x = self.Conv2d_2_1x1(x)
# 32 x 74 x74
x = self.Conv2d_3_3x3(x)
# 96 x 74 x 74
x0 = F.max_pool2d(x, kernel_size=3, stride=2)
# x0=96 x 36 x 36
x1 = self.incept_block_1(x0)
# x1=256 x 18 x 18
x2 = self.incept_block_2(x1)
# x2 = 502 x 9 x9
x3 = self.incept_block_3(x2)
# x3 = 512 x 9 x9
x = F.avg_pool2d(x3, kernel_size=9, stride=1)
# 512 x1 x1
x = F.dropout(x, training=self.training)
# 1 x 1 x 512
x = x.view(x.size(0), -1)
# 512
pred_class = self.final_fc(x)
confidence = F.softmax(pred_class,dim=1)
if self.feat_out:
return x0, x1, x2, x3
else:
return pred_class,confidence
class BasicConv2d(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return F.relu(x, inplace=True)
class InceptBlock1(nn.Module):
def __init__(self):
super(InceptBlock1, self).__init__()
self.module_a = InceptionA(in_channels=96, b_1x1_out=32, b_5x5_1_out=32, b_5x5_2_out=32,
b_3x3_1_out=32, b_3x3_2_out=48, b_3x3_3_out=48, b_pool_out=16) # C_out =128
self.module_b = InceptionB(in_channels=128, b_1x1_1_out=64, b_1x1_2_out=80,
b_3x3_1_out=32, b_3x3_2_out=48, b_3x3_3_out=48)
def forward(self, x):
x = self.module_a(x)
x = self.module_b(x)
return x
class InceptBlock2(nn.Module):
def __init__(self):
super(InceptBlock2, self).__init__()
self.module_a = InceptionA(in_channels=256, b_1x1_out=112, b_5x5_1_out=32, b_5x5_2_out=48,
b_3x3_1_out=48, b_3x3_2_out=64, b_3x3_3_out=64, b_pool_out=64) # C_out = 288
self.module_b = InceptionB(in_channels=288, b_1x1_1_out=64, b_1x1_2_out=86,
b_3x3_1_out=96, b_3x3_2_out=128, b_3x3_3_out=128) # C_out = 502??!
def forward(self, x):
x = self.module_a(x)
x = self.module_b(x)
return x
class InceptBlock3(nn.Module):
def __init__(self):
super(InceptBlock3, self).__init__()
self.module_a = InceptionA(in_channels=502, b_1x1_out=176, b_5x5_1_out=96, b_5x5_2_out=160,
b_3x3_1_out=80, b_3x3_2_out=112, b_3x3_3_out=112, b_pool_out=64)
self.module_b = InceptionA(in_channels=512, b_1x1_out=176, b_5x5_1_out=96, b_5x5_2_out=160,
b_3x3_1_out=96, b_3x3_2_out=112, b_3x3_3_out=112, b_pool_out=64)
def forward(self, x):
x = self.module_a(x)
x = self.module_b(x)
return x
class InceptionA(nn.Module):
def __init__(
self,
in_channels,
b_1x1_out,
b_5x5_1_out,
b_5x5_2_out,
b_3x3_1_out,
b_3x3_2_out,
b_3x3_3_out,
b_pool_out
):
super(InceptionA, self).__init__()
self.branch1x1 = BasicConv2d(in_channels, b_1x1_out, kernel_size=1) # H_out = H_in
self.branch5x5_1 = BasicConv2d(in_channels, b_5x5_1_out, kernel_size=1)
self.branch5x5_2 = BasicConv2d(b_5x5_1_out, b_5x5_2_out, kernel_size=3, padding=1) # H_out = H_in
self.branch3x3dbl_1 = BasicConv2d(in_channels, b_3x3_1_out, kernel_size=1)
self.branch3x3dbl_2 = BasicConv2d(b_3x3_1_out, b_3x3_2_out, kernel_size=3, padding=1)
self.branch3x3dbl_3 = BasicConv2d(b_3x3_2_out, b_3x3_3_out, kernel_size=3, padding=1)
self.branch_pool = BasicConv2d(in_channels, b_pool_out, kernel_size=1)
def forward(self, x):
branch1x1 = self.branch1x1(x)
branch5x5 = self.branch5x5_1(x)
branch5x5 = self.branch5x5_2(branch5x5)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
branch_pool = F.avg_pool2d(x, kernel_size=1, stride=1)
branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)
class InceptionB(nn.Module):
def __init__(
self,
in_channels,
b_1x1_1_out,
b_1x1_2_out,
b_3x3_1_out,
b_3x3_2_out,
b_3x3_3_out
):
super(InceptionB, self).__init__()
self.branch1x1_1 = BasicConv2d(in_channels, b_1x1_1_out, kernel_size=1)
self.branch1x1_2 = BasicConv2d(b_1x1_1_out, b_1x1_2_out, kernel_size=3, stride=2, padding=1)
self.branch3x3dbl_1 = BasicConv2d(in_channels, b_3x3_1_out, kernel_size=1)
self.branch3x3dbl_2 = BasicConv2d(b_3x3_1_out, b_3x3_2_out, kernel_size=3, padding=1)
self.branch3x3dbl_3 = BasicConv2d(b_3x3_2_out, b_3x3_3_out, kernel_size=3, stride=2, padding=1)
def forward(self, x):
branch3x3 = self.branch1x1_1(x)
branch3x3 = self.branch1x1_2(branch3x3)
branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
branch_pool = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
outputs = [branch3x3, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)
def weight_init(m):
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight.data)