Here is my code:
def train(model, train_loader, optimizer, criterion, writer, epoch, device, local_rank, model_name=None):
    total_loss, all_preds, all_labels = 0, [], []
    model.train()
    cv_log.info(f"dist-{local_rank} Training.........")
    total_data = len(train_loader)
    scaler = GradScaler()
    for i, (image, labels) in enumerate(train_loader):
        labels = labels.to(device, non_blocking=True).float()
        image = image.to(device, non_blocking=True).float()
        optimizer.zero_grad()
        with autocast(device_type='cuda'):
            logits = model(image)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        # unscale before clipping so the norm is computed on the real gradients
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        all_preds.append(logits.detach())
        all_labels.append(labels.detach())
        acc, macro_f1, macro_recall, balanced_acc = compute_metrics(logits.detach(), labels.detach())
        if dist.get_rank() == 0 and writer is not None:
            writer.add_scalar('Loss/Train-Batch', loss.item(), epoch * total_data + i)
            writer.add_scalar('ACC/Train-Batch', acc, epoch * total_data + i)
            writer.add_scalar('Macro-F1/Train-Batch', macro_f1, epoch * total_data + i)
            writer.add_scalar('Macro-Recall/Train-Batch', macro_recall, epoch * total_data + i)
            writer.add_scalar('Balance-Acc/Train-Batch', balanced_acc, epoch * total_data + i)
        if i % 200 == 0:
            cv_log.info(f"dist-{local_rank} {model_name}--EPOCH: {epoch} || {i}/{len(train_loader)}"
                        f" Loss Value: {loss.item()} macro-f1:{macro_f1:.2f} macro-recall:{macro_recall:.2f} balanced_acc:{balanced_acc:.2f} acc: {acc:.2f}")
    all_preds, all_labels = torch.cat(all_preds), torch.cat(all_labels)
    acc, macro_f1, macro_recall, balanced_acc = compute_metrics(all_preds, all_labels)
    if dist.get_rank() == 0 and writer is not None:
        writer.add_scalar('Loss/Train-Epoch', total_loss / len(train_loader), epoch)
        writer.add_scalar('Acc/Train-Epoch', acc, epoch)
        writer.add_scalar('Macro-F1/Train-Epoch', macro_f1, epoch)
        writer.add_scalar('Macro-Recall/Train-Epoch', macro_recall, epoch)
        writer.add_scalar('Balance-Acc/Train-Epoch', balanced_acc, epoch)
    return total_loss / len(train_loader), acc, macro_f1, macro_recall, balanced_acc
def test(model, test_loader, criterion, writer, epoch, local_rank, device):
    model.eval()
    cv_log.info(f"dist-{local_rank} eva......")
    total_loss, all_preds, all_labels = 0, [], []
    total_data = len(test_loader)
    with torch.no_grad():
        with tqdm(total=total_data) as pbar:
            for i, (image, labels) in enumerate(test_loader):
                labels = labels.to(device, non_blocking=True).float()
                image = image.to(device, non_blocking=True).float()
                with autocast(device_type='cuda'):
                    logits = model(image)
                    loss = criterion(logits, labels)
                total_loss += loss.item()
                all_preds.append(logits)
                all_labels.append(labels)
                acc, macro_f1, macro_recall, balanced_acc = compute_metrics(logits.detach(), labels.detach())
                writer.add_scalar('Loss/Test-Batch', loss.item(), epoch * total_data + i)
                writer.add_scalar('ACC/Test-Batch', acc, epoch * total_data + i)
                writer.add_scalar('Macro-F1/Test-Batch', macro_f1, epoch * total_data + i)
                writer.add_scalar('Macro-Recall/Test-Batch', macro_recall, epoch * total_data + i)
                writer.add_scalar('Balance-Acc/Test-Batch', balanced_acc, epoch * total_data + i)
                pbar.update(1)
    all_preds, all_labels = torch.cat(all_preds), torch.cat(all_labels)
    acc, macro_f1, macro_recall, balanced_acc = compute_metrics(all_preds, all_labels)
    if dist.get_rank() == 0:
        writer.add_scalar('Loss/Test-Epoch', total_loss / len(test_loader), epoch)
        writer.add_scalar('Acc/Test-Epoch', acc, epoch)
        writer.add_scalar('Macro-F1/Test-Epoch', macro_f1, epoch)
        writer.add_scalar('Macro-Recall/Test-Epoch', macro_recall, epoch)
        writer.add_scalar('Balance-Acc/Test-Epoch', balanced_acc, epoch)
    return total_loss / len(test_loader), acc, macro_f1, macro_recall, balanced_acc
def main(args: ModelArgs):
    args.save_dir = f'./{args.data_type}-save-models/'
    if dist.get_rank() == 0:
        print("*" * 30, f"\n{args}\n", "*" * 30)
        print(args.model_name)
    if args.model_name == 'vit':
        args.batch_size = 32
        args.lr = 1e-3
    elif args.model_name == 'resnet50':
        args.batch_size = 58
        args.lr = 1e-4
    elif args.model_name == 'resnet101':
        args.batch_size = 36
        args.lr = 1e-4
    elif args.model_name == 'resnet152':
        args.batch_size = 24
        args.lr = 1e-3
    if dist.get_rank() == 0:
        os.makedirs(os.path.join(args.save_dir, args.model_name), exist_ok=True)
        store_path = os.path.join(args.save_dir, args.model_name)
        writer = SummaryWriter(f"{store_path}/scalar/")
    else:
        writer = None
    # data
    small_dataset = True  # True
    train_data = CVDataset(args.train_path, image_size=args.image_size, mask=args.data_type,
                           small_dataset=small_dataset)
    test_data = CVDataset(args.test_path, image_size=args.image_size, mask=args.data_type,
                          small_dataset=small_dataset)
    train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, batch_size=args.batch_size, shuffle=False,
                                  sampler=train_sampler, num_workers=4)
    test_dataloader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                                 num_workers=4)
    # model config
    model = CVModel(args).to(device)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[local_rank],
                                                      output_device=local_rank)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    # criterion = nn.CrossEntropyLoss().to(device)
    class_counts = [10434, 78316, 213687, 166781, 38284, 5457, 311, 9]
    criterion = ClassBalancedFocalLoss(class_counts, beta=0.999, gamma=2.0).to(device)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
    best_acc = 0.0
    for epoch in range(args.epochs):
        train_sampler.set_epoch(epoch)
        train_loss, train_acc, train_macro_f1, train_macro_recall, train_balance_acc = train(
            model, train_dataloader, optimizer, criterion, writer,
            epoch, device, local_rank, model_name=args.model_name)
        scheduler.step(train_loss)
        # if train_balance_acc >= best_acc and dist.get_rank() == 0:
        #     best_acc = train_balance_acc
        #     torch.save(model.module.state_dict(), f"{store_path}/best_model_{args.model_name}.pth")
        #     cv_log.info(f"Best model({best_acc}) saved!")
        # cv_log.info(f"Epoch:{epoch}/{args.epochs} {args.model_name} Train Loss: {train_loss:.4f} Lr {scheduler.get_last_lr()[0]} "
        #             f"Train Acc: {train_acc:.4f}, Macro-F1: {train_macro_f1:.4f}, Macro-Recall: {train_macro_recall:.4f} Balanced-Acc: {train_balance_acc:.2f}")
        if dist.get_rank() == 0:
            test_loss, test_acc, test_macro_f1, test_macro_recall, test_balance_acc = test(
                model, test_dataloader, criterion, writer, epoch, local_rank, device)
            if test_acc >= best_acc:
                best_acc = test_acc
                torch.save(model.module.state_dict(), f"{store_path}/best_model_{args.model_name}.pth")
                cv_log.info(f"Best model({best_acc}) saved!")
            cv_log.info(f"Epoch:{epoch}/{args.epochs} {args.model_name} Train Loss: {train_loss:.4f} "
                        f"Train Acc: {train_acc:.4f}, Macro-F1: {train_macro_f1:.4f}, Macro-Recall: {train_macro_recall:.4f} Balanced-Acc: {train_balance_acc:.2f}")
            cv_log.info(f"Epoch:{epoch}/{args.epochs} {args.model_name} Test Loss: {test_loss:.4f} "
                        f"Test Acc: {test_acc:.4f}, Macro-F1: {test_macro_f1:.4f}, Macro-Recall: {test_macro_recall:.4f} Balanced-Acc: {test_balance_acc:.2f}")
        dist.barrier()
    dist.destroy_process_group()
But when I run training, the output shows a problem:
2025-04-17 22:14:23 - INFO - cv_model.py:82 - Model-resnet50 Total Params: 22.42M Train Params: 22.42M
2025-04-17 22:14:24 - INFO - cv_model.py:82 - Model-resnet50 Total Params: 22.42M Train Params: 22.42M
2025-04-17 22:14:24 - INFO - cv_model.py:111 - dist-0 Training.........
2025-04-17 22:14:24 - INFO - cv_model.py:111 - dist-1 Training.........
2025-04-17 22:14:26 - INFO - cv_model.py:140 - dist-1 resnet50--EPOCH: 0 || 0/36 Loss Value: 0.13872744143009186 macro-f1:0.01 macro-recall:0.01 balanced_acc:0.01 acc: 0.03
2025-04-17 22:14:26 - INFO - cv_model.py:140 - dist-0 resnet50--EPOCH: 0 || 0/36 Loss Value: 0.15321113169193268 macro-f1:0.01 macro-recall:0.01 balanced_acc:0.01 acc: 0.03
2025-04-17 22:14:58 - INFO - cv_model.py:156 - dist-0 eva......
0%| | 0/18 [00:00<?, ?it/s]
2025-04-17 22:14:58 - INFO - cv_model.py:111 - dist-1 Training.........
From the log you can see that rank 0 is still evaluating on the validation set while rank 1 has already started the next training epoch, and the two processes end up blocking each other. I worked out the fix: calling model.module(image) instead of model(image) inside test() solves the problem.
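For anyone hitting the same hang, here is a minimal sketch of the change inside test(), assuming model is the DistributedDataParallel wrapper and that the deadlock comes from rank 0 issuing DDP collective calls during the wrapped forward (e.g. buffer broadcasts) that get mismatched with the other ranks waiting at dist.barrier(); this is my reading of the symptom, not a confirmed diagnosis. Bypassing the wrapper with model.module avoids any inter-process communication during single-rank evaluation:

    # inside the test() loop -- forward through the raw module, not the DDP wrapper,
    # so no collective communication is triggered while only rank 0 is evaluating
    net = model.module if hasattr(model, "module") else model  # fall back if the model is not wrapped
    with autocast(device_type='cuda'):
        logits = net(image)
        loss = criterion(logits, labels)

An alternative would be to let every rank run test() and all-reduce the metrics afterwards so the ranks never diverge, but the model.module call is the smaller change to the existing code.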