Distributed Training with Skipping Training Steps

I am trying to train a distributed model where the loss is only computed when an instance is captured in the prediction. Basically, if the object of interest is detected in the scan, continue to the loss and gradient update; if not, skip to the next batch. The difficulty is that the model requires find_unused_parameters=True with torch.nn.parallel.DistributedDataParallel to accommodate its control flow, and with find_unused_parameters=True the training hangs at the backward pass once a batch is skipped on one rank. How could this be addressed? I’ve seen the static_graph argument in torch==1.11+, but it is not clear whether that is the right approach or the best solution.

I’ve made a simple model here to demonstrate what I am trying to do; the skip is forced rather than determined by a variable from training.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import autocast, GradScaler
import os
import numpy as np


class SampleNet(nn.Module):
    def __init__(self):
        super(SampleNet, self).__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1)
        # self.unused_conv = nn.Conv2d(in_channels=2, out_channels=12, kernel_size=3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.conv(x)
        x = self.avg_pool(x).squeeze(-1).squeeze(-1)
        return x


class SampleDataset(data.Dataset):
    def __init__(self):
        super(SampleDataset, self).__init__()
        self.rand_samples = np.random.random((4, 1, 32, 32))
        self.gt = np.array([1, 1, 1, 1])

    def __len__(self):
        return 4

    def __getitem__(self, sample_idx):
        sample_torch = torch.from_numpy(self.rand_samples[sample_idx]).to(torch.float32)
        sample_gt_torch = torch.from_numpy(np.array([1.0])).to(torch.float32)

        return sample_torch, sample_gt_torch


def main(gpu):
    device = f'cuda:{gpu}'

    world_size = 2
    batch_size = 1

    rank = gpu

    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=rank)

    net = SampleNet()
    net = net.to(device)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank],
                                                    find_unused_parameters=False,
                                                    broadcast_buffers=False,
                                                    )

    scaler = GradScaler()
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)

    epochs = 1

    train_dataset = SampleDataset()
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False, num_workers=1,
                              pin_memory=True, drop_last=True, sampler=train_sampler,
                              persistent_workers=True, prefetch_factor=1)

    print('starting training')
    for epoch in range(epochs):
        net.train()
        train_sampler.set_epoch(epoch)
        for sample_index, (sample, target) in enumerate(train_loader):
            optimizer.zero_grad()

            sample = sample.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            with autocast():
                pred_sample = net(sample)

            # Force rank 1 to skip the first sample (stands in for the real detection check)
            skip_sample = False
            if gpu == 1 and sample_index == 0:
                skip_sample = True

            if not skip_sample:
                loss = F.binary_cross_entropy_with_logits(input=pred_sample, target=target)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                print('backward completed: ', gpu)

            torch.distributed.barrier()

            if gpu == 0:
                print('Here "a": ', sample_index)
            if gpu == 1:
                print('Here "b": ', sample_index)

            torch.distributed.barrier()


    if gpu == 0:
        print('GPU 0 finished.')
    if gpu == 1:
        print('GPU 1 finished.')

    torch.distributed.barrier()

    if gpu == 0:
        print('GPU 0 after barrier.')

    if gpu == 1:
        print('GPU 1 after barrier.')


if __name__ == '__main__':
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    mp.set_start_method('spawn')
    tot_workers = 2
    mp.spawn(main, nprocs=tot_workers)

For your program, it seems that you need to remove torch.distributed.barrier().

E.g., rank0 skips the first batch and rank1 does not. With DDP, rank1 will launch the gradient sync in the backward pass under the hood, but rank0 never launches it. So rank0 and rank1 are de-synced, the two ranks will never both reach torch.distributed.barrier(), and the whole program will eventually time out.
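
One way to keep the ranks in lockstep, for example, is to run the backward pass on every rank but zero out the loss for a skipped sample, so the gradient all-reduce that DDP launches in backward still fires on every rank. A rough sketch against the loop in the example above (assumes the same pred_sample, target, scaler and optimizer variables; untested):

# Inside the training loop, replacing the skip_sample check around the loss/backward.
loss = F.binary_cross_entropy_with_logits(input=pred_sample, target=target)
if skip_sample:
    loss = loss * 0.0  # gradients become zero, but DDP's sync in backward still runs

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()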

The torch.distributed.barrier() is not the problem. It can be removed, but one of the processes is still unable to finish. More realistically, with epochs set to 2 and the torch.distributed.barrier() calls removed, it still shows that one of the processes does not finish.

def main(gpu):
    device = f'cuda:{gpu}'

    world_size = 2
    batch_size = 1

    rank = gpu

    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=rank)

    net = SampleNet()
    net = net.to(device)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank],
                                                    find_unused_parameters=False,
                                                    broadcast_buffers=False,
                                                    )

    scaler = GradScaler()
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)

    epochs = 2  # Was originally set to 1

    train_dataset = SampleDataset()
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False, num_workers=1,
                              pin_memory=True, drop_last=True, sampler=train_sampler,
                              persistent_workers=True, prefetch_factor=1)

    print('starting training')
    for epoch in range(epochs):
        net.train()
        train_sampler.set_epoch(epoch)
        for sample_index, (sample, target) in enumerate(train_loader):
            optimizer.zero_grad()

            sample = sample.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            with autocast():
                pred_sample = net(sample)

            # Force rank 1 to skip the first sample (stands in for the real detection check)
            skip_sample = False
            if gpu == 1 and sample_index == 0:
                skip_sample = True

            if not skip_sample:
                loss = F.binary_cross_entropy_with_logits(input=pred_sample, target=target)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                print('backward completed: ', gpu)
    
            # Block commented out
            # torch.distributed.barrier()
            #
            # if gpu == 0:
            #     print('Here "a": ', sample_index)
            # if gpu == 1:
            #     print('Here "b": ', sample_index)

            # torch.distributed.barrier()


    if gpu == 0:
        print('GPU 0 finished.')
    if gpu == 1:
        print('GPU 1 finished.')

    torch.distributed.barrier()

    if gpu == 0:
        print('GPU 0 after barrier.')

    if gpu == 1:
        print('GPU 1 after barrier.')


Any solution to that issue?

No, the only solution I found was to use a lot of torch.distributed.barrier() calls, skip the whole step on every rank when one process hit the issue, and use mp.Value to pass the skip state between the processes.
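
An alternative that needs fewer barriers is to have the ranks agree on the skip decision inside the training loop with an all_reduce, so either every rank runs backward for a given step or every rank skips it. A rough sketch against the example loop above (assumes the same device, pred_sample, target, scaler and optimizer variables; untested):

# Inside the training loop, after skip_sample has been decided on each rank.
skip_flag = torch.tensor([1.0 if skip_sample else 0.0], device=device)
dist.all_reduce(skip_flag, op=dist.ReduceOp.MAX)  # skip the step if any rank wants to skip

if skip_flag.item() == 0:
    loss = F.binary_cross_entropy_with_logits(input=pred_sample, target=target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()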