Encountered a hang while training with DistributedSampler (DDP)

My code looks like this:

import os

import torch
import torch.distributed as dist
import torch.optim as optim
from torch.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# CVDataset, CVModel, ClassBalancedFocalLoss, compute_metrics, cv_log, ModelArgs,
# as well as device/local_rank and the process-group setup, come from my own modules
# and are omitted here.

def train(model, train_loader, optimizer, criterion, writer, epoch, device, local_rank, model_name=None):
    total_loss, all_preds, all_labels = 0, [], []
    
    model.train()
    cv_log.info(f"dist-{local_rank} Training.........")

    total_data = len(train_loader)
    scaler = GradScaler()
    for i, (image, labels) in enumerate(train_loader):
        labels = labels.to(device, non_blocking=True).float()
        image = image.to(device, non_blocking=True).float()
        optimizer.zero_grad()

        with autocast(device_type= 'cuda'):
            logits = model(image)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale first so clipping sees the true gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        all_preds.append(logits.detach())
        all_labels.append(labels.detach())
        acc, macro_f1, macro_recall, balanced_acc = compute_metrics(logits.detach(), labels.detach())
        if dist.get_rank()== 0 and writer is not None:
            writer.add_scalar('Loss/Train-Batch', loss.item(), epoch * total_data + i)
            writer.add_scalar('ACC/Train-Batch', acc, epoch * total_data + i)
            writer.add_scalar('Macro-F1/Train-Batch', macro_f1, epoch * total_data + i)
            writer.add_scalar('Macro-Recall/Train-Batch', macro_recall, epoch * total_data + i)
            writer.add_scalar('Balance-Acc/Train-Batch', balanced_acc, epoch * total_data + i)

        if i%200==0:
            cv_log.info(f"dist-{local_rank} {model_name}--EPOCH: {epoch} || {i}/{len(train_loader)}"
                        f" Loss Value: {loss.item()}  macro-f1:{macro_f1:.2f} macro-recall:{macro_recall:.2f} balanced_acc:{balanced_acc:.2f} acc: {acc:.2f}")
    
    all_preds, all_labels = torch.cat(all_preds), torch.cat(all_labels)
    acc, macro_f1, macro_recall, balanced_acc = compute_metrics(all_preds.detach(), all_labels.detach())   
    if dist.get_rank()== 0 and writer is not None:
        writer.add_scalar('Loss/Train-Epoch', total_loss / len(train_loader), epoch)
        writer.add_scalar('Acc/Train-Epoch', acc, epoch)
        writer.add_scalar('Macro-F1/Train-Epoch', macro_f1, epoch)
        writer.add_scalar('Macro-Recall/Train-Epoch', macro_recall, epoch)
        writer.add_scalar('Balance-Acc/Train-Epoch', balanced_acc, epoch)

    return total_loss / len(train_loader), acc, macro_f1, macro_recall, balanced_acc

def test(model, test_loader, criterion, writer, epoch, local_rank, device):
    model.eval()
    cv_log.info(f"dist-{local_rank} eva......")
    total_loss, all_preds, all_labels = 0, [], []
    
    total_data = len(test_loader)
    with torch.no_grad():
        with tqdm(total= total_data) as pbar:
            for i, (image, labels) in enumerate(test_loader):
                labels = labels.to(device, non_blocking=True).float()
                image = image.to(device, non_blocking=True).float()
                with autocast(device_type= 'cuda'):
                    logits = model(image)
                    loss = criterion(logits, labels)
                
                total_loss += loss.item()
                all_preds.append(logits)
                all_labels.append(labels)

                acc, macro_f1, macro_recall, balanced_acc = compute_metrics(logits.detach(), labels.detach())

                writer.add_scalar('Loss/Test-Batch', loss.item(), epoch * total_data + i)
                writer.add_scalar('ACC/Test-Batch', acc, epoch * total_data + i)
                writer.add_scalar('Macro-F1/Test-Batch', macro_f1, epoch * total_data + i)
                writer.add_scalar('Macro-Recall/Test-Batch', macro_recall, epoch * total_data + i)
                writer.add_scalar('Balance-Acc/Test-Batch', balanced_acc, epoch * total_data + i)
                pbar.update(1)
        pbar.close()
    
    all_preds, all_labels = torch.cat(all_preds), torch.cat(all_labels)
    acc, macro_f1, macro_recall, balanced_acc = compute_metrics(all_preds, all_labels)
    if dist.get_rank()== 0:
        writer.add_scalar('Loss/Test-Epoch', total_loss / len(test_loader), epoch)
        writer.add_scalar('Acc/Test-Epoch', acc, epoch)
        writer.add_scalar('Macro-F1/Test-Epoch', macro_f1, epoch)
        writer.add_scalar('Macro-Recall/Test-Epoch', macro_recall, epoch)
        writer.add_scalar('Balance-Acc/Test-Epoch', balanced_acc, epoch)

    return total_loss / len(test_loader), acc, macro_f1, macro_recall, balanced_acc

def main(args: ModelArgs):
    args.save_dir = f'./{args.data_type}-save-models/'
    if dist.get_rank()==0:
        print("*"*30, f"\n{args}\n", "*"*30)
        print(args.model_name)

    if args.model_name== 'vit':
        args.batch_size = 32
        args.lr = 1e-3
    elif args.model_name== 'resnet50':
        args.batch_size = 58
        args.lr = 1e-4
    elif args.model_name== 'resnet101':
        args.batch_size = 36
        args.lr = 1e-4
    elif args.model_name== 'resnet152':
        args.batch_size= 24
        args.lr = 1e-3

    if dist.get_rank()==0:
        os.makedirs(os.path.join(args.save_dir, args.model_name), exist_ok= True)
        store_path = os.path.join(args.save_dir, args.model_name)
        writer = SummaryWriter(f"{store_path}/scalar/")
    else:
        writer = None

    # data
    small_dataset = True # True
    train_data = CVDataset(args.train_path, image_size= args.image_size, mask= args.data_type,
                           small_dataset= small_dataset)
    test_data = CVDataset(args.test_path, image_size=args.image_size, mask= args.data_type,
                          small_dataset= small_dataset)
    train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, batch_size= args.batch_size, shuffle= False, 
                                  sampler = train_sampler, num_workers= 4)
    test_dataloader = DataLoader(test_data, batch_size= args.batch_size, shuffle= False,
                                num_workers= 4)
    
    # model config
    model = CVModel(args).to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, 
                                                      device_ids=[local_rank], 
                                                      output_device=local_rank)
    optimizer = optim.AdamW(model.parameters(), lr= args.lr)
    # criterion = nn.CrossEntropyLoss().to(device)

    class_counts = [10434, 78316, 213687, 166781, 38284, 5457, 311, 9]
    criterion = ClassBalancedFocalLoss(class_counts, beta=0.999, gamma=2.0).to(device)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

    best_acc = 0.0
    for epoch in range(args.epochs):
        train_sampler.set_epoch(epoch)
        train_loss, train_acc, train_macro_f1, train_macro_recall, train_balance_acc = train(model, train_dataloader,
                                                                                             optimizer, criterion, writer, 
                                                                                             epoch, device, local_rank, 
                                                                                             model_name= args.model_name)
        scheduler.step(train_loss)
        # if train_balance_acc>= best_acc and dist.get_rank()== 0:
        #     best_acc = train_balance_acc
        #     torch.save(model.module.state_dict(), f"{store_path}/best_model_{args.model_name}.pth")
        #     cv_log.info(f"Best model({best_acc}) saved!")
        #     cv_log.info(f"Epoch:{epoch}/{args.epochs} {args.model_name}  Train Loss: {train_loss:.4f} Lr {scheduler.get_last_lr()[0]} " 
        #                 f"Train Acc: {train_acc:.4f}, Macro-F1: {train_macro_f1:.4f}, Macro-Reacall: {train_macro_recall:.4f} Banlance-Acc: {train_balance_acc:.2f}")

        if dist.get_rank()== 0:
            test_loss, test_acc, test_macro_f1, test_macro_recall, test_balance_acc = test(model, test_dataloader,criterion, writer, 
                                                                 epoch, local_rank, device)
            if test_acc >= best_acc:
                best_acc = test_acc
                torch.save(model.module.state_dict(), f"{store_path}/best_model_{args.model_name}.pth")
                cv_log.info(f"Best model({best_acc}) saved!")
            cv_log.info(f"Epoch:{epoch}/{args.epochs} {args.model_name}  Train Loss: {train_loss:.4f}" 
                        f"Train Acc: {train_acc:.4f}, Macro-F1: {train_macro_f1:.4f}, Macro-Reacall: {train_macro_recall:.4f} Banlance-Acc: {train_balance_acc:.2f}")
            cv_log.info(f"Epoch:{epoch}/{args.epochs} {args.model_name}  Train Loss: {test_loss:.4f}" 
                        f"Train Acc: {test_acc:.4f}, Macro-F1: {test_macro_f1:.4f}, Macro-Reacall: {test_macro_recall:.4f} Banlance-Acc: {test_balance_acc:.2f}")
        dist.barrier()        
    dist.destroy_process_group()

But during training I get the following output, which shows the problem:

2025-04-17 22:14:23 - INFO - cv_model.py:82 - Model-resnet50 Total Params: 22.42M Train Params: 22.42M
2025-04-17 22:14:24 - INFO - cv_model.py:82 - Model-resnet50 Total Params: 22.42M Train Params: 22.42M
2025-04-17 22:14:24 - INFO - cv_model.py:111 - dist-0 Training.........
2025-04-17 22:14:24 - INFO - cv_model.py:111 - dist-1 Training.........
2025-04-17 22:14:26 - INFO - cv_model.py:140 - dist-1 resnet50--EPOCH: 0 || 0/36 Loss Value: 0.13872744143009186  macro-f1:0.01 macro-recall:0.01 balanced_acc:0.01 acc: 0.03
2025-04-17 22:14:26 - INFO - cv_model.py:140 - dist-0 resnet50--EPOCH: 0 || 0/36 Loss Value: 0.15321113169193268  macro-f1:0.01 macro-recall:0.01 balanced_acc:0.01 acc: 0.03
2025-04-17 22:14:58 - INFO - cv_model.py:156 - dist-0 eva......
  0%|          | 0/18 [00:00<?, ?it/s]
2025-04-17 22:14:58 - INFO - cv_model.py:111 - dist-1 Training.........

I noticed that rank 0 was still evaluating on the test set while rank 1 had already started the next training epoch, and the two processes then blocked each other (the run hangs).
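My guess at what is going on (this is my own reading of how DDP behaves, so treat it as an assumption): test() is only called on rank 0, but it still runs the forward pass through the DDP wrapper, and a DDP forward issues collective communication that every rank must join, e.g. the default broadcast_buffers=True broadcast of the BatchNorm buffers in resnet50. Simplified, the mismatch looks like this:

# illustration only: assumes 2 ranks and the NCCL backend
if dist.get_rank() == 0:
    ddp_model(images)   # DDP forward -> broadcasts module buffers, needs all ranks
dist.barrier()          # rank 1 skips the forward and issues a barrier instead
# the two ranks now wait on different collectives, so the job hangs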


Update: I figured it out. Using model.module(image) instead of model(image) inside test() solves the problem.
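For anyone hitting the same thing, this is roughly what the change looks like inside my test() loop (just a sketch of the idea; raw_model is an illustrative name):

raw_model = model.module if hasattr(model, "module") else model  # unwrap the DDP wrapper
with torch.no_grad():
    with autocast(device_type='cuda'):
        logits = raw_model(image)   # plain nn.Module forward: no DDP collectives,
                                    # so evaluating on rank 0 only no longer blocks the other ranks
        loss = criterion(logits, labels)

An alternative would be to let every rank run test() and only log/save on rank 0, so that all ranks keep issuing the same collectives; unwrapping with model.module was simply the smaller change for me.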