Here is the minimal code of my training setup.
"""import os
import random
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchvision.transforms as transforms
from data.Datasets import CustomDataset
from models.gmic import gmic, Evaluator, Trainer, LossFunctions
from utils.Configv2 import Configv2
def setup_ddp_env(rank, world_size):
    """Join the NCCL process group as worker *rank* of *world_size*.

    All workers rendezvous at the same (hard-coded) address/port, which
    must be set in the environment before ``init_process_group``.
    """
    rendezvous = {'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'}
    os.environ.update(rendezvous)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
def create_dataset(cfg):
    """Build the train and validation datasets from paths in *cfg*.

    Bug fixed: the original assigned to a local named ``transforms`` —
    shadowing the imported ``torchvision.transforms`` module (raising
    UnboundLocalError on the right-hand side) — and then referenced
    ``transform_train`` / ``transform_val``, which were never defined.

    Returns:
        (train, val): two ``CustomDataset`` instances.
    """
    transform_train = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5)])
    # NOTE(review): validation gets no augmentation — confirm this matches
    # the intent of the (previously undefined) transform_val.
    transform_val = None
    train = CustomDataset(cfg.train_path, transform=transform_train)
    val = CustomDataset(cfg.val_path, transform=transform_val)
    return train, val
def create_dataloader(dataset, rank, world_size, batch_size=16, pin_memory=False, num_workers=0, shuffle=False):
    """Wrap *dataset* in a DataLoader partitioned across DDP ranks.

    Each rank sees a disjoint ~1/world_size slice of the dataset via
    ``DistributedSampler``. The DataLoader's own ``shuffle`` must stay
    False because shuffling is the sampler's job.

    Args:
        dataset: any map-style dataset (``__len__`` + ``__getitem__``).
        rank: this process's rank, selects the sampler's shard.
        world_size: total number of ranks.
        batch_size / pin_memory / num_workers: forwarded to DataLoader.
        shuffle: new, backward-compatible knob — shuffle per-epoch inside
            the sampler (use together with ``sampler.set_epoch``). Default
            False preserves the original deterministic ordering.

    Returns:
        torch.utils.data.DataLoader over this rank's shard.
    """
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=shuffle, drop_last=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory,
                            num_workers=num_workers, drop_last=False,
                            shuffle=False, sampler=sampler)
    return dataloader
def create_model(cfg, pretrained):
    """Instantiate a GMIC model and warm-start it from a checkpoint.

    Fix: load the checkpoint with ``map_location='cpu'``. Without it,
    every spawned rank deserializes the tensors onto whatever device the
    checkpoint was saved from (typically ``cuda:0``), piling all ranks'
    copies onto GPU 0. The weights are moved to the right device later by
    the caller's ``model.to(rank)``.

    ``strict=False`` tolerates missing/unexpected keys — NOTE(review):
    confirm this is deliberate (partial warm-start) and not masking a
    mismatched checkpoint.
    """
    model = gmic.GMIC(cfg.gmic_parameters)
    state_dict = torch.load(pretrained, map_location='cpu')
    model.load_state_dict(state_dict, strict=False)
    return model
def create_training_setup(rank, world_size, cfg, pretrained):
    """Assemble everything one DDP rank needs: sharded loaders, a
    DDP-wrapped model on this rank's GPU, and the Trainer/Evaluator.

    Returns:
        (train_loader, trainer, val_loader, evaluator)
    """
    train_ds, val_ds = create_dataset(cfg)
    loaders = {
        'train': create_dataloader(train_ds, rank, world_size),
        'val': create_dataloader(val_ds, rank, world_size),
    }

    # Move weights to this rank's device before wrapping in DDP.
    net = create_model(cfg, pretrained).to(rank)
    net = DDP(net, device_ids=[rank], output_device=rank,
              find_unused_parameters=False)

    criterion = LossFunctions.GMICLoss()
    optimizer = optim.Adam(net.parameters(), lr=cfg.train.lr, weight_decay=0.001)

    # NB: 'parellel_mode' is the Trainer's own (misspelled) keyword.
    trainer = Trainer.Trainer(criterion=criterion, model=net,
                              optimizer=optimizer, total_epochs=cfg.train.epoch,
                              train_loader=loaders['train'], parellel_mode=True)
    evaluator = Evaluator.Evaluator(model=net, data_loader=loaders['val'])
    return loaders['train'], trainer, loaders['val'], evaluator
def main(rank, world_size, cfg, pretrained, output_path, weight_path):
    """Per-process entry point, spawned once per GPU by ``mp.spawn``.

    Fixes:
    - Seeds are set here, inside the child: seeding in the parent under
      ``if __name__ == '__main__'`` does not carry over to spawned workers.
    - ``range(1, cfg.train.epoch)`` silently dropped the final epoch
      (off-by-one); the range now covers epochs 1..cfg.train.epoch.

    NOTE(review): ``output_path`` / ``weight_path`` are accepted but never
    used here — presumably meant for checkpoint/metric saving; confirm.
    """
    torch.manual_seed(0)
    random.seed(0)

    setup_ddp_env(rank, world_size)
    train_loader, trainer, val_loader, evaluator = create_training_setup(
        rank, world_size, cfg, pretrained)

    for epoch in range(1, cfg.train.epoch + 1):
        # Tell the samplers which epoch this is so each epoch's shard
        # ordering differs (required for DistributedSampler shuffling).
        train_loader.sampler.set_epoch(epoch)
        val_loader.sampler.set_epoch(epoch)
        train_metrics = trainer.fit()
        if epoch % 10 == 0:  # validate every 10th epoch
            val_metrics = evaluator.evaluate()

    dist.destroy_process_group()
if __name__ == '__main__':
    # One process per GPU; keep nprocs and the world_size passed to main()
    # in sync via a single variable instead of two hard-coded 2s.
    world_size = 2

    cfg = Configv2('config_path')
    pretrained = 'pretrained_model_path'
    # Bug fix: output_path / weight_path were referenced in args=(...)
    # but never defined, so spawn raised NameError before starting.
    output_path = 'output_path'
    weight_path = 'weight_path'

    torch.manual_seed(0)
    random.seed(0)

    # mp.spawn passes the rank as the first positional argument to main;
    # args supplies the remaining parameters.
    mp.spawn(
        main,
        args=(world_size, cfg, pretrained, output_path, weight_path),
        nprocs=world_size,
    )