Multi-Node DDP Training Profiler

I am training on 3 servers using distributed data parallel (DDP) with 1 GPU per server, so 3 GPUs in total. How can I profile such a training run? Can I collect each worker's data (e.g., running times and memory usage) and analyze it on the master node?
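Something like the sketch below is what I have in mind: every rank collects a few numbers per epoch, and rank 0 gathers and prints (or logs) them. This is only a rough sketch assuming torch.distributed.all_gather_object is the right collective for this; the function name and the stat names are placeholders I made up.

import torch
import torch.distributed as dist
import psutil

def report_stats_to_master(epoch: int, epoch_time: float):
    # Each rank builds a small dict of per-epoch metrics.
    stats = {
        "rank": dist.get_rank(),
        "epoch": epoch,
        "epoch_time_s": epoch_time,
        "cuda_max_mem_mb": torch.cuda.max_memory_allocated() / 2**20,
        "host_rss_mb": psutil.Process().memory_info().rss / 2**20,
    }
    # all_gather_object works with the NCCL backend as long as
    # torch.cuda.set_device() was called for the local rank.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, stats)
    if dist.get_rank() == 0:
        for s in gathered:
            print(s)  # or write to TensorBoard / a CSV on the master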

Here is my trainer script:

import torch
import torch.nn.functional as F
import os
import time
import psutil
import argparse

from torch.utils.data import DataLoader
from dataloader import MyDataset
from wideresnet import build_wideresnet

from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group


def ddp_setup():
    init_process_group(backend="nccl")
    # Bind this process to its local GPU so NCCL uses the right device.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        snapshot_path: str
    ) -> None:

        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.model = model.to(self.local_rank)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.epochs_run = 0
        self.snapshot_path = snapshot_path
        

        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)

        self.model = DDP(self.model, device_ids=[self.local_rank], find_unused_parameters=True)

    def _load_snapshot(self, snapshot_path):
        loc = f"cuda:{self.gpu_id}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")
        

    def _save_snapshot(self, epoch):
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch,
        }
        torch.save(snapshot, self.snapshot_path)
        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")

    def train(self, max_epochs: int):
        
        for epoch in range(self.epochs_run, max_epochs):
            
            b_sz = len(next(iter(self.train_data))[0])
            print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
    
            self.train_data.sampler.set_epoch(epoch)
            # One profiler context per epoch; each rank writes its own trace to
            # runs/gpu<rank>, which can be opened with TensorBoard's profiler plugin.
            with torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
                profile_memory=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(f"runs/gpu{self.global_rank}"),
            ) as prof:
        
                for b_id, (source, targets) in enumerate(self.train_data):

                    source = source.to(self.local_rank)
                    targets = targets.to(self.local_rank)

                    self.optimizer.zero_grad()
                    output = self.model(source)
                    loss = F.cross_entropy(output, targets)
                    loss.backward()
                    self.optimizer.step()
                
                    prof.step()
            

            #if self.local_rank == 0 and epoch % self.save_every == 0:
            #    self._save_snapshot(epoch)
        
def load_train_objs(args):
    train_set = MyDataset(data_dir=args.data_dir, annotations=args.annotations, num_classes=args.num_classes)  # load your dataset
    model = build_wideresnet(depth=28, widen_factor=4, dropout=0, num_classes=args.num_classes)  # load your model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer



def main(args, save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
    ddp_setup()   
    dataset, model, optimizer = load_train_objs(args)
    train_data = DataLoader(dataset, batch_size=batch_size, pin_memory=True, shuffle=False, sampler=DistributedSampler(dataset))  
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
    trainer.train(total_epochs)
    destroy_process_group()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='distributed training job')
    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    parser.add_argument('--batch_size', default=50, type=int, help='Input batch size on each device (default: 50)')

    args = parser.parse_args()
    args.data_dir = "/media/data-science"
    args.annotations = "annotations.csv"
    args.num_classes = 20
    
    main(args, args.save_every, args.total_epochs, args.batch_size)
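For completeness, I start one process per node with torchrun along these lines (the script name, rendezvous endpoint, and positional arguments below are placeholders):

torchrun --nnodes=3 --nproc_per_node=1 --rdzv_id=<job_id> --rdzv_backend=c10d --rdzv_endpoint=<master_ip>:29500 trainer.py <total_epochs> <save_every>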