I am struggling to integrate k-fold cross-validation into my script.
The script uses three sets of data (training, validation and inference), and the corresponding functions are defined in main_worker. I want to use cross-validation to save the best model, test it in the inference function I defined (precision, recall, top-1 accuracy and confusion matrix), and visualize the training/validation curves.
Can someone help me, please?
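Concretely, this is roughly the fold loop I have in mind (only a sketch, not something that runs as-is: it assumes the dataset returned by get_training_data can be split with sklearn's KFold and torch.utils.data.Subset, it ignores the distributed sampler, and cross_validate, n_folds and best_model_path are names I made up; the loggers/tb_writer are just passed through to the existing train_epoch/val_epoch):

# Sketch of a k-fold loop around the existing helpers (assumptions noted above).
import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset

def cross_validate(opt, train_data, criterion, n_folds=5,
                   train_logger=None, train_batch_logger=None,
                   val_logger=None, tb_writer=None):
    kfold = KFold(n_splits=n_folds, shuffle=True, random_state=opt.manual_seed)
    best_val_loss = float('inf')
    best_model_path = opt.result_path / 'best_model.pth'
    for fold, (train_idx, val_idx) in enumerate(
            kfold.split(np.arange(len(train_data)))):
        # fresh model and optimizer for every fold so the folds stay independent
        model = generate_model(opt).to(opt.device)
        optimizer = torch.optim.SGD(model.parameters(), lr=opt.learning_rate)
        train_loader = DataLoader(Subset(train_data, train_idx),
                                  batch_size=opt.batch_size, shuffle=True,
                                  num_workers=opt.n_threads, pin_memory=True)
        val_loader = DataLoader(Subset(train_data, val_idx),
                                batch_size=opt.batch_size, shuffle=False,
                                num_workers=opt.n_threads, pin_memory=True)
        for epoch in range(1, opt.n_epochs + 1):
            train_loss, train_acc = train_epoch(epoch, train_loader, model,
                                                criterion, optimizer, opt.device,
                                                get_lr(optimizer), train_logger,
                                                train_batch_logger, tb_writer,
                                                opt.distributed)
            val_loss, val_acc = val_epoch(epoch, val_loader, model, criterion,
                                          opt.device, val_logger, tb_writer,
                                          opt.distributed)
            # keep only the best checkpoint across all folds and epochs
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), best_model_path)
    return best_model_path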
This is the script I am working with:
# import the necessary libraries
# .....
def json_serial(obj):
    if isinstance(obj, Path):
        return str(obj)
def get_opt():
    # .............. code here
    return opt

def resume_model(resume_path, arch, model):
    # ....... code here
    return model

def resume_train_utils(resume_path, begin_epoch, optimizer, scheduler):
    # ....... code here
    return begin_epoch, optimizer, scheduler

def get_normalize_method(mean, std, no_mean_norm, no_std_norm):
    return Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
def get_train_utils(opt, model_parameters):
    # .... spatial and temporal transformations for the training
    train_data = get_training_data(opt.video_path, opt.annotation_path,
                                   opt.dataset, opt.input_type, opt.file_type,
                                   spatial_transform, temporal_transform)
    if opt.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=opt.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=opt.n_threads,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               worker_init_fn=worker_init_fn)
    # ............... rest of the code
    return (train_loader, train_sampler, train_logger, train_batch_logger,
            optimizer, scheduler)
def get_val_utils(opt):
    # ..... spatial and temporal transformations for validation
    val_data, collate_fn = get_validation_data(opt.video_path,
                                               opt.annotation_path, opt.dataset,
                                               opt.input_type, opt.file_type,
                                               spatial_transform,
                                               temporal_transform)
    if opt.distributed:
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_data, shuffle=False)
    else:
        val_sampler = None
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=(opt.batch_size //
                                                         opt.n_val_samples),
                                             shuffle=False,
                                             num_workers=opt.n_threads,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             worker_init_fn=worker_init_fn,
                                             collate_fn=collate_fn)
    # ........ code
    return val_loader, val_logger
def get_inference_utils(opt):
    # ...... transformations for the inference
    inference_data, collate_fn = get_inference_data(
        opt.video_path, opt.annotation_path, opt.dataset, opt.input_type,
        opt.file_type, opt.inference_subset, spatial_transform,
        temporal_transform)
    inference_loader = torch.utils.data.DataLoader(
        inference_data,
        batch_size=opt.inference_batch_size,
        shuffle=False,
        num_workers=opt.n_threads,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn)
    return inference_loader, inference_data.class_names
def save_checkpoint(save_file_path, epoch, arch, model, optimizer, scheduler):
    # ..... code
    torch.save(save_states, save_file_path)
def main_worker(index, opt):
    random.seed(opt.manual_seed)
    np.random.seed(opt.manual_seed)
    torch.manual_seed(opt.manual_seed)
    if index >= 0 and opt.device.type == 'cuda':
        opt.device = torch.device(f'cuda:{index}')
    if opt.distributed:
        opt.dist_rank = opt.dist_rank * opt.ngpus_per_node + index
        dist.init_process_group(backend='nccl',
                                init_method=opt.dist_url,
                                world_size=opt.world_size,
                                rank=opt.dist_rank)
        opt.batch_size = int(opt.batch_size / opt.ngpus_per_node)
        opt.n_threads = int(
            (opt.n_threads + opt.ngpus_per_node - 1) / opt.ngpus_per_node)
    opt.is_master_node = not opt.distributed or opt.dist_rank == 0
    model = generate_model(opt)
    if opt.batchnorm_sync:
        assert opt.distributed, 'SyncBatchNorm only supports DistributedDataParallel.'
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if opt.pretrain_path:
        model = load_pretrained_model(model, opt.pretrain_path, opt.model,
                                      opt.n_finetune_classes)
    if opt.resume_path is not None:
        model = resume_model(opt.resume_path, opt.arch, model)
    model = make_data_parallel(model, opt.distributed, opt.device)
    if opt.pretrain_path:
        parameters = get_fine_tuning_parameters(model, opt.ft_begin_module)
    else:
        parameters = model.parameters()
    if opt.is_master_node:
        print(model)
    criterion = CrossEntropyLoss().to(opt.device)
    if not opt.no_train:
        (train_loader, train_sampler, train_logger, train_batch_logger,
         optimizer, scheduler) = get_train_utils(opt, parameters)
        if opt.resume_path is not None:
            opt.begin_epoch, optimizer, scheduler = resume_train_utils(
                opt.resume_path, opt.begin_epoch, optimizer, scheduler)
            if opt.overwrite_milestones:
                scheduler.milestones = opt.multistep_milestones
    if not opt.no_val:
        val_loader, val_logger = get_val_utils(opt)
    if opt.tensorboard and opt.is_master_node:
        from torch.utils.tensorboard import SummaryWriter
        if opt.begin_epoch == 1:
            tb_writer = SummaryWriter(log_dir=opt.result_path)
        else:
            tb_writer = SummaryWriter(log_dir=opt.result_path,
                                      purge_step=opt.begin_epoch)
    else:
        tb_writer = None
    prev_val_loss = None
    # --------------------------------
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    # --------------------------------
    for i in range(opt.begin_epoch, opt.n_epochs + 1):
        if not opt.no_train:
            if opt.distributed:
                train_sampler.set_epoch(i)
            current_lr = get_lr(optimizer)
            train_loss, train_accuracy = train_epoch(
                i, train_loader, model, criterion, optimizer, opt.device,
                current_lr, train_logger, train_batch_logger, tb_writer,
                opt.distributed)
            if i % opt.checkpoint == 0 and opt.is_master_node:
                save_file_path = opt.result_path / 'save_{}.pth'.format(i)
                save_checkpoint(save_file_path, i, opt.arch, model, optimizer,
                                scheduler)
            # -----------------------------
            train_losses.append(train_loss)
            train_accuracies.append(train_accuracy)
            # ------------------------------
        if not opt.no_val:
            prev_val_loss, prev_val_accuracy = val_epoch(
                i, val_loader, model, criterion, opt.device, val_logger,
                tb_writer, opt.distributed)
            # ---------------------------------
            val_losses.append(prev_val_loss)
            val_accuracies.append(prev_val_accuracy)
            # ---------------------------------
        if not opt.no_train and opt.lr_scheduler == 'multistep':
            scheduler.step()
        elif not opt.no_train and opt.lr_scheduler == 'plateau':
            scheduler.step(prev_val_loss)
    plt.figure(figsize=(10, 7))
    plt.plot(range(opt.begin_epoch, opt.n_epochs + 1), val_accuracies,
             color='r', linestyle='-', label='validation Accuracy')
    plt.plot(range(opt.begin_epoch, opt.n_epochs + 1), train_accuracies,
             color='b', linestyle='-', label='train Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.tight_layout()
    plt.savefig(opt.result_path / 'accuracy_curve.png')
    plt.close()
    plt.figure(figsize=(10, 7))
    plt.plot(range(opt.begin_epoch, opt.n_epochs + 1), train_losses,
             color='b', linestyle='-', label='train Loss')
    plt.plot(range(opt.begin_epoch, opt.n_epochs + 1), val_losses,
             color='r', linestyle='-', label='validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.tight_layout()
    plt.savefig(opt.result_path / 'loss_curve.png')
    plt.close()
    if opt.inference:
        inference_loader, inference_class_names = get_inference_utils(opt)
        inference_result_path = opt.result_path / '{}.json'.format(
            opt.inference_subset)
        inference.inference(inference_loader, model, inference_result_path,
                            inference_class_names, opt.inference_no_average,
                            opt.output_topk)
if __name__ == '__main__':
    opt = get_opt()
    opt.device = torch.device('cpu' if opt.no_cuda else 'cuda')
    if not opt.no_cuda:
        cudnn.benchmark = True
    if opt.accimage:
        torchvision.set_image_backend('accimage')
    opt.ngpus_per_node = torch.cuda.device_count()
    if opt.distributed:
        opt.world_size = opt.ngpus_per_node * opt.world_size
        mp.spawn(main_worker, nprocs=opt.ngpus_per_node, args=(opt,))
    else:
        main_worker(-1, opt)
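And for the evaluation side, this is the kind of summary I would like the inference step to produce for the best model (again just a sketch: it assumes I can collect every clip's predicted class index and ground-truth label into two lists, and it uses sklearn.metrics, which my script does not import yet; summarize_predictions is a name I made up):

# Sketch: precision, recall, top-1 accuracy and confusion matrix from
# per-clip predictions; y_true and y_pred are plain lists of class indices.
from sklearn.metrics import (accuracy_score, confusion_matrix,
                             precision_score, recall_score)

def summarize_predictions(y_true, y_pred, class_names):
    top1 = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    print(f'top-1 accuracy : {top1:.4f}')
    print(f'macro precision: {precision:.4f}')
    print(f'macro recall   : {recall:.4f}')
    print('confusion matrix (rows = true class, columns = predicted class):')
    print(cm)
    return top1, precision, recall, cm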