I’ve successfully set up DDP with the pytorch tutorials, but I cannot find any clear documentation about testing/evaluation. I want to do 2 things:
- Track train/val loss in tensorboard
- Evaluate my model straight after training (in same script).
However, both of these fail: (1) consistently gives me 2 entries per epoch, even though I do not use a distributed sampler for the validation loss and it should only execute if gpu_id==0, (2) evaluation on the test set doesn’t even happen, or it is excessively slow. The print statement inside the if statement checking the GPU id never gets printed, even though I use the exact same if statement earlier for snapshot saving/val loss calculation and there it enters the if statement normally.
Only the train dataloader uses a distributed sampler.
Now I can live with the strange tensorboard graphs, but I would really like to be able to evaluate my model in the same script.
here’s my code:
# ython -m torch.distributed.launch main_train_psnr.py --opt options/swinir/train_swinir_sr_classical.json
import os
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from baselines.swinir.utils import EarlyStopper, calculate_psnr, calculate_ssim
from baselines.swinir.dataloader import get_swinir_dataloaders
from baselines.swinir.swinir import SwinIR
DATADIR = os.environ["DATADIR"]
DATASETSDIR = os.environ["DATASETSDIR"]
RESULTSDIR = os.environ["RESULTSDIR"]
SCALE = 2
def ddp_setup():
init_process_group(backend='nccl')
class Trainer:
def __init__(self, model: torch.nn.Module, train_data: DataLoader, val_data: DataLoader, test_data: DataLoader, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler._LRScheduler, save_every: int, snapshot_path: str, tensorboard_path: str) -> None:
self.gpu_id = int(os.environ["LOCAL_RANK"])
self.model = model.to(self.gpu_id)
self.train_data = train_data
self.val_data = val_data
self.test_data = test_data
self.optimizer = optimizer
self.scheduler = scheduler
self.save_every = save_every
self.snapshot_path = snapshot_path
self.epochs_run = 0
self.batch_size = 16
if self.gpu_id == 0:
self.logger = SummaryWriter(tensorboard_path)
else:
self.logger = None
self.early_stopper = EarlyStopper(patience=10)
if os.path.exists(snapshot_path):
print("Loading snapshot")
self._load_snapshot(snapshot_path)
self.model = DDP(self.model, device_ids=[
self.gpu_id], find_unused_parameters=True)
def _load_snapshot(self, snapshot_path):
loc = f"cuda:{self.gpu_id}"
snapshot = torch.load(snapshot_path, map_location=loc)
self.model.load_state_dict(snapshot["MODEL_STATE"])
self.epochs_run = snapshot["EPOCHS_RUN"]
self.early_stopper = snapshot["EARLY_STOPPER"]
print(f"Resuming training from snapshot at Epoch {self.epochs_run}")
def _run_batch(self, source, targets, epoch):
self.optimizer.zero_grad()
output = self.model(source)
loss = F.l1_loss(output, targets)
if self.gpu_id == 0:
self.logger.add_scalar("Loss/train", loss, epoch)
loss.backward()
self.optimizer.step()
self.scheduler.step()
def _validate_batch(self, val_source, val_targets):
with torch.no_grad():
output = self.model(val_source)
return F.l1_loss(output, val_targets).item()
def _get_epoch_val_loss(self):
val_loss = 0
for source, targets in self.val_data:
source = source.to(self.gpu_id)
targets = targets.to(self.gpu_id)
val_loss += self._validate_batch(source, targets)
return val_loss
def _run_epoch(self, epoch):
b_sz = len(next(iter(self.train_data))[0])
self.train_data.sampler.set_epoch(epoch)
for source, targets in self.train_data:
source = source.to(self.gpu_id)
targets = targets.to(self.gpu_id)
self._run_batch(source, targets, epoch)
if self.gpu_id == 0:
self.model.eval()
val_loss = self._get_epoch_val_loss()
self.logger.add_scalar("Loss/val", val_loss, epoch)
self.model.train()
# if self.early_stopper.early_stop(val_loss):
# break
def _save_snapshot(self, epoch):
snapshot = {
"MODEL_STATE": self.model.module.state_dict(),
"EPOCHS_RUN": epoch,
"EARLY_STOPPER": self.early_stopper,
}
torch.save(snapshot, self.snapshot_path)
print(
f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")
def train(self, max_epochs: int):
for epoch in range(self.epochs_run, max_epochs):
print(f"[{self.gpu_id}] epoch {epoch}")
self._run_epoch(epoch)
if self.gpu_id == 0 and epoch % self.save_every == 0:
self._save_snapshot(epoch)
def test(self):
psnr_values = []
ssim_values = []
self.model.eval()
with torch.no_grad():
for lr, hr in self.test_data:
print("test batch")
lr = lr.to(self.gpu_id)
sr = self.model(lr).to("cpu")
psnr_value = calculate_psnr(hr, sr).item()
psnr_values.append(psnr_value)
ssim_val = calculate_ssim(hr, sr).item()
ssim_values.append(ssim_val)
self.model.train()
psnr, ssim = (sum(psnr_values)/len(psnr_values),
sum(ssim_values)/len(ssim_values))
print(f"PSNR: {psnr} | SSIM: {ssim}")
def train_test(self, max_epochs: int):
self.train(max_epochs)
print("done training")
if self.gpu_id == 0:
print(f"[{self.gpu_id}] start eval...")
self.test()
def load_train_objs(dataset_name: str, image_size: int, scale: int):
scale = 2
train_batch_size=16
print("Load data ...")
test_loader, valid_loader, train_loader = get_swinir_dataloaders(
dataset_name, train_batch_size)
model = SwinIR(img_size=image_size, scale=scale, window_size=8, mlp_ratio=2,
embed_dim=180, upsampler='pixelshuffle') # load your model
optimizer = Adam(model.parameters(), lr=2e-4, weight_decay=0)
scheduler = MultiStepLR(
optimizer, [250000, 400000, 450000, 475000, 500000], 0.5)
return train_loader, valid_loader, test_loader, model, optimizer, scheduler
def main(save_every: int, total_epochs: int, snapshot_path: str = "/data1/wasalaj/snapshots/swinir/snapshot.pt", scale: int = 2):
ddp_setup()
train_loader, valid_loader, test_loader, model, optimizer, scheduler = load_train_objs(
"cerrado", 64, scale)
trainer = Trainer(model, train_loader, valid_loader, test_loader,
optimizer, scheduler, save_every, snapshot_path, "/data1/wasalaj/tensorboard/swinir_test")
trainer.train_test(total_epochs)
print("done training...")
destroy_process_group()