Process freezes when I change the dataset object during training with DistributedDataParallel

The entire process freezes, with GPU utilization pinned at 100%, after line 386 has been invoked. The distributed backend is 'nccl'. Please find the main function below.

def main(args):
torch.backends.cudnn.enabled = True

utils.init_distributed_mode(args)
# print('git:\n  {}\n'.format(utils.get_sha()))

torch.autograd.set_detect_anomaly(True)
print(args)

device = torch.device(args.device)

# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

model, criterion, postprocessors = build_model(args) 
model.to(device)

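# keep a reference to the raw module so parameters and state_dicts can be accessed outside the DDP wrapper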
model_without_ddp = model
if args.distributed:
    # process_groups[0 if utils.get_rank() <= 3 else 1]
    # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.gpu], find_unused_parameters=True)
    model_without_ddp = model.module

n_parameters = sum(p.numel() for p in model.parameters()
                   if p.requires_grad)
print('number of params:', n_parameters)

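# stage-specific freezing: stage 1 freezes the graph classifier, stage 2 freezes the linear classifier
# (and con_head when the supcon loss coefficient is 0); only the remaining parameters go to the optimizer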
if args.stage == 1:
    for name, value in model_without_ddp.named_parameters():
        if 'graph_classifier' in name:
            value.requires_grad = False
        # if 'linear_classifier' in name:
        #   value.requires_grad = False

    learned_params = filter(lambda p: p.requires_grad,
                            model_without_ddp.parameters())
    learned_params = list(learned_params)

elif args.stage == 2:
    for name, value in model_without_ddp.named_parameters():
        if 'linear_classifier' in name:
            value.requires_grad = False
        if 'con_head' in name and args.supcon_loss_coef == 0:
            value.requires_grad = False

    learned_params = filter(lambda p: p.requires_grad,
                            model_without_ddp.parameters())
    learned_params = list(learned_params)

# param_dicts = [
#     {"params": [p for n, p in model_without_ddp.named_parameters() if p.requires_grad]}
# ]

# optim = torch.optim.AdamW(learned_params,
#                               lr=args.lr,
#                               weight_decay=args.weight_decay) 
# if args.distributed:
#   optim = ZeroRedundancyOptimizer(learned_params,optimizer_class=torch.optim.AdamW,lr=args.lr,weight_decay=args.weight_decay)
# else:
# if args.stage == 1:
optim = torch.optim.SGD(learned_params, lr=args.lr,
                        momentum=args.momentum, nesterov=True,
                        weight_decay=args.weight_decay)
# else:
#     optim = torch.optim.AdamW(learned_params, lr=args.lr,
#                               weight_decay=args.weight_decay)

# lr_scheduler = torch.optim.lr_scheduler.StepLR(optim, args.lr_drop)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, args.lr_drop, gamma=0.1)

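# datasets, samplers (DistributedSampler when running with DDP) and data loaders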
dataset_train,collate_fn = build_dataset(split='train', args=args)
dataset_val,collate_fn = build_dataset(split='test', args=args)

# dataset_train = UCF24(split='train',frame_sampling_rate=args.frame_sampling_rate)
# dataset_val = UCF24(split='test',frame_sampling_rate=args.frame_sampling_rate)

if args.distributed:
    sampler_train = DistributedSampler(dataset_train)
    sampler_val = DistributedSampler(dataset_val, shuffle=False)
else:
    sampler_train = torch.utils.data.RandomSampler(dataset_train)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
# sampler_val = torch.utils.data.SequentialSampler(dataset_val)
batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=False)

data_loader_train = DataLoader(dataset_train,batch_sampler=batch_sampler_train, num_workers=args.num_workers, 
                        collate_fn=collate_fn, pin_memory=True)


data_loader_val = DataLoader(dataset_val,
                             args.batch_size,
                             sampler=sampler_val,
                             drop_last=False,
                             collate_fn=collate_fn,
                             pin_memory=True,
                             num_workers=args.num_workers)

# if args.frozen_weights is not None:
#     checkpoint = torch.load(args.frozen_weights, map_location='cpu')
#     model_without_ddp.rtd.load_state_dict(checkpoint['model'])

output_dir = Path(args.output_dir)
# if args.resume:
#     if args.resume.startswith('https'):
#         checkpoint = torch.hub.load_state_dict_from_url(args.resume,
#                                                         map_location='cpu',
#                                                         check_hash=True)
#     else:
#         print(("=> loading checkpoint '{}'".format(args.resume)))
#         checkpoint = torch.load(args.resume)
#         args.start_epoch = checkpoint['epoch']
#         pretrained_dict = checkpoint['model']
#         # only resume part of model parameter
#         model_dict = model_without_ddp.state_dict()
#         pretrained_dict = {
#             k: v
#             for k, v in pretrained_dict.items() if k in model_dict
#         }
#         model_dict.update(pretrained_dict)
#         model_without_ddp.load_state_dict(model_dict)
#         # main_model.load_state_dict(checkpoint['state_dict'])
#         print(("=> loaded '{}' (epoch {})".format(args.resume,
#                                                   checkpoint['epoch'])))

if args.load:
    checkpoint = torch.load(args.load, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])

best_video_map_05 = 0

best_frame_map = 0

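# evaluation-only mode: run inference on the test split, dump detections and video-mAP, then exit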
if args.eval:
    evaluator, eval_loss_dict = evaluate(model, criterion, data_loader_val,
                                         device, args, postprocessors)

    results,_ = evaluator.summarize()
    
    results_pd,all_vid_links = get_all_links(results)

    results_pd.to_csv(args.output_dir+'results.csv',index=False)

    test_stats, _ = eval_props_detection_video_map(results_pd,args.output_dir,mode='time')

    print('test_stats', test_stats)
    return

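# evaluate a sweep of saved epoch checkpoints and keep the one with the best mAP@0.5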
if args.eval_all:
    load = [0]+list(np.arange(1,15,2))+[14]
    # load = [1]
    load_path = args.output_dir + 'checkpoint_epoch'

    for l in load:
        path = load_path + str(l) + '.pth'
        checkpoint = torch.load(path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        if args.stage == 1:
            evaluator, eval_loss_dict = evaluate(model, criterion, 
                                             data_loader_val, device, args, postprocessors)
        else:
            evaluator, eval_loss_dict = evaluate_active(model, criterion, 
                                         data_loader_val, device, args, postprocessors)


        results,_ = evaluator.summarize()
        
        results_pd,all_vid_links = get_all_links(results)
        # print('results_pd[\'label-idx\'] == 1 is: ', sum(results_pd['label-idx'] == 1))
        # results_pd.to_csv(args.output_dir + 'results.csv', index=False)

        test_stats, _ = eval_props_detection_video_map(results_pd,args.output_dir)

        print('test_stats', test_stats)

        log_stats = {
            **{f'test_MAP@{k}': v
               for k, v in test_stats.items()}, 'epoch': int(l),
            'n_parameters': n_parameters}

        if (float(test_stats['0.5']['map']) > best_video_map_05):
            best_video_map_05 = float(test_stats['0.5']['map'])
            with (output_dir / 'log_best_map05.txt').open('w') as f:
                f.write(json.dumps(log_stats) + '\n')
            checkpoint_path = output_dir / 'checkpoint_best_map05.pth'
            utils.save_on_master(
                {
                    'model': model_without_ddp.state_dict(),
                    'epoch': l,
                    'args': args,
                }, checkpoint_path)
            results_pd.to_csv(args.output_dir+'best_results.csv',index=False)
            with open(args.output_dir+'best_vid_links.pickle','wb') as file:
                pickle.dump(all_vid_links,file, protocol=pickle.HIGHEST_PROTOCOL)
    return

print('Start training')
start_time = time.time()

epoch_list = []
train_loss_list = {}
eval_loss_list = {}
test_stats_list = {}


   
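# mixed-precision grad scaler and the main epoch loop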
scaler = GradScaler()
for epoch in range(args.start_epoch, args.epochs):
    if args.distributed:
        sampler_train.set_epoch(epoch)

    if args.stage == 1:
        train_stats, train_loss_dict = train_one_epoch(model, criterion,
                                                   data_loader_train,
                                                   optim, 
                                                   device,
                                                   epoch, scaler, args)
    else:
        agg_flag = False  # make sure agg_flag is defined even if no active-learning branch matches
        if 'mill' in args.active_learning_type:
            batch_budget = args.active_budget / (args.epochs // args.active_cycle)
            train_stats, train_loss_dict, aggregator = train_one_epoch_active_mill(model, criterion,
                                                       data_loader_train, 
                                                       optim, 
                                                       device,
                                                       epoch, scaler, batch_budget, args)
            agg_flag = (epoch+1) % args.active_cycle == 0
        elif 'pseudo' in args.active_learning_type:
            if (epoch+1) <= args.warm_up_epochs:
                batch_budget = args.active_budget / 2
            else:
                batch_budget = (args.active_budget/2) / ((args.epochs - args.warm_up_epochs) // args.active_cycle)
            train_stats, train_loss_dict, aggregator = train_one_epoch_active_pseudo(model, criterion,
                                                       data_loader_train, 
                                                       optim, 
                                                       device,
                                                       epoch, scaler, batch_budget, args)
            agg_flag = ((epoch+1) % args.active_cycle == 0 and (epoch+1) >= args.warm_up_epochs)

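        # at the end of an active-learning cycle, write the acquired labels back into dataset_train
        # in place (this is the dataset change mentioned in the title)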
        if agg_flag:
            acquisitions = aggregator.summarize()

            for k,v in acquisitions.items():
                dataset_train.person_detection_label_check[k] = v[-1].detach().cpu().numpy()

           
                # for k,v in acquisitions.items():
                #   print(np.any(data_loader_train.dataset.person_detection_label_check[k] == v[-1].cpu().numpy()))
            
            print('{} number of rois selected out of {}'.format(dataset_train.tot_rois('selected'),
                    dataset_train.tot_rois()),'\n')


    for key, value in train_loss_dict.items():
        if key in [
                'loss_mill','loss_supcon', 'loss_ce'
        ]:
            try:
                train_loss_list[key].append(value.mean().detach().cpu().numpy())
            except KeyError:
                train_loss_list[key] = [value.mean().detach().cpu().numpy()]

    lr_scheduler.step()

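    # checkpointing: refresh checkpoint.pth every epoch, with extra snapshots at LR drops and at the final epoch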
    if (epoch+1) % args.epochs == 0 and args.output_dir:
        checkpoint_path = output_dir / 'checkpoint_epoch{}.pth'.format(
            epoch)
        utils.save_on_master(
            {
                'model': model_without_ddp.state_dict(),
                'optimizer': optim.state_dict(),
                'scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args,
            }, checkpoint_path)
    if args.output_dir:
        checkpoint_paths = [output_dir / 'checkpoint.pth']
        # extra checkpoint before LR drop and every 100 epochs

        if np.any([(epoch + 1) % l==0 for l in args.lr_drop]) or (epoch + 1) % args.epochs == 0:
            checkpoint_paths.append(output_dir /
                                    f'checkpoint{epoch:04}.pth')
        for checkpoint_path in checkpoint_paths:
            utils.save_on_master(
                {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optim.state_dict(),
                    'scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                }, checkpoint_path)
    
    
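    # periodic validation: evaluate, compute video-mAP, log, and track the best mAP@0.5 checkpoint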
    if (epoch+1) % args.val_check == 0:
        
        # if args.stage == 1:
        evaluator, eval_loss_dict = evaluate(model, criterion, 
                                                data_loader_val, device,args,postprocessors)
        # else:
        #   evaluator, eval_loss_dict = evaluate_active(model, criterion, 
        #                                    data_loader_val, device, args, postprocessors)
        results,_ = evaluator.summarize()
        
        results_pd,_ = get_all_links(results)
        test_stats, _ = eval_props_detection_video_map(results_pd,args.output_dir)

        
        for k, v in test_stats.items():

            try:
                if isinstance(v,dict):
                    k1 = list(v.keys())
                    test_stats_list[k].append(float(v[k1[0]]) * 100)
                else:
                    test_stats_list[k].append(float(v) * 100)
            except KeyError:
                if isinstance(v,dict):
                    k1 = list(v.keys())
                    test_stats_list[k] = [float(v[k1[0]]) * 100]
                else:
                    test_stats_list[k] = [float(v) * 100]

        for key, value in eval_loss_dict.items():
            if key in [
                    'loss_mill','loss_ce'
            ]:
                try:
                    eval_loss_list[key].append(value.mean().detach().cpu().numpy())
                except KeyError:
                    eval_loss_list[key] = [value.mean().detach().cpu().numpy()]

        print('test_stats', test_stats)

        # log_stats = {**{f'train_{k}': v
        #      for k, v in train_stats.items()}}
        log_stats = {
                **{f'train_{k}': v
                   for k, v in train_stats.items()},
                **{f'test_MAP@{k}': v
                   for k, v in test_stats.items()}, 'epoch': epoch,
                'n_parameters': n_parameters,
                'percentage_selected_rois': dataset_train.tot_rois('selected')/dataset_train.tot_rois()}

        if (float(test_stats['0.5']['map']) > best_video_map_05):
            best_video_map_05 = float(test_stats['0.5']['map'])
            with (output_dir / 'log_best_map05.txt').open('w') as f:
                f.write(json.dumps(log_stats) + '\n')
            checkpoint_path = output_dir / 'checkpoint_best_map05.pth'
            utils.save_on_master(
                {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optim.state_dict(),
                    'scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                }, checkpoint_path)

        

        # if args.output_dir :
        #   checkpoint_path = output_dir / 'checkpoint_epoch{}.pth'.format(
        #     epoch)
        #   utils.save_on_master(
        #     {
        #         'model': model_without_ddp.state_dict(),
        #         'optimizer': optim.state_dict(),
        #         'scheduler': lr_scheduler.state_dict(),
        #         'epoch': epoch,
        #         'args': args,
        #     }, checkpoint_path)

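        # only the main process writes the epoch log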
        if args.output_dir and utils.is_main_process():
            with (output_dir / 'log.txt').open('a') as f:
                f.write(json.dumps(log_stats) + '\n')
        epoch_list.append(epoch)

    



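# after training, save which ROIs ended up with labels during stage-2 active learning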
if args.stage == 2:
    pd_labeled_list = dataset_train.person_detection_label_check
    with open(args.output_dir+'labeled_roi_list.pickle','wb') as file:
        pickle.dump(pd_labeled_list,file)
        


# total_time = time.time() - start_time
# total_time_str = str(datetime.timedelta(seconds=int(total_time)))
# print('Training time {}'.format(total_time_str))

if __name__ == '__main__':
    parser = argparse.ArgumentParser('training weakly but actively',
                                     parents=[get_args_parser()])
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    args = parser.parse_args()
    # sys.settrace(trace)
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    faulthandler.enable()
    # faulthandler._sigsegv()
    main(args)