How to use DistributedDataParallel when the loss needs the whole dataset + gradient accumulation?

Hi!

I am working on image retrieval and I would like to compute the loss on the entire dataset.
In order to do that, I first have to compute the features for the whole dataset, and then compute the loss in a batch-wise manner and do gradient accumulation (due to memory constraints).
I wanted to use DistributedDataParallel to speed up my training, but did not manage to get it working.
The training would be on a single node with 3 to 4 GPUs.

This is an example of what I am basically trying to do:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, CenterCrop
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from tqdm import tqdm

transform = Compose(
    (Resize((256,256)),
     CenterCrop(224),
     ToTensor(),
     Normalize(
         mean=[0.485, 0.456, 0.406],
         std=[0.229, 0.224, 0.225])
    )
)
dts_train = CIFAR10("/users/r/ramzie/datasets/", train=True, transform=transform, download=True)

def get_loader(dts, sampler=None):
    return DataLoader(
        dts,
        batch_size=128,
        shuffle=False,
        drop_last=False,
        pin_memory=True,
        num_workers=10,
        sampler=sampler,
    )


class L2Norm(nn.Module):
    def forward(self, X):
        return F.normalize(X, dim=-1)

def criterion(di, lb, features, labels):
    scores = torch.mm(di, features.t())
    gt = lb.view(-1, 1) == labels.unsqueeze(0)
    return F.relu(-scores[gt]).mean() + F.relu(scores[~gt]).mean()

net = resnet18(pretrained=True)
net.fc = L2Norm()
_ = net.cuda()

opt = torch.optim.SGD(net.parameters(), 0.1)
scaler = torch.cuda.amp.GradScaler()


for e in range(2):
    
    loader = get_loader(dts_train)
    features = []
    labels = []
    # We first compute the features and labels for the whole dataset
    # This could be a first distributed loop
    for (x, y) in tqdm(loader, 'computing features'):
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                feat = net(x.cuda())
            features.append(feat)
            labels.append(y)
    features = torch.cat(features)
    labels = torch.cat(labels).cuda()
    
    ####################################################
    ####################################################
    
    loader = get_loader(dts_train)
    # This is the bottleneck, which could also be distributed
    for (x, y) in tqdm(loader, 'accumulating gradient'):
        with torch.cuda.amp.autocast():
            di = net(x.cuda())
            lb = y.cuda()
            loss = criterion(di, lb, features, labels) / len(features)
            
        # gradient accumulation for the entire dataset
        scaler.scale(loss).backward()

    ####################################################
    ####################################################

    # only at the end perform optimization (full batch)
    scaler.step(opt)
    scaler.update()
    opt.zero_grad()  # reset the accumulated gradients before the next epoch

Thank you if you have any time to help!

Hi, what exactly is the issue that you run into when using DDP? Is the training not sped up as you'd expect, or does it run into memory issues (since you mentioned memory constraints)?

Hi,

I had several issues, one of them indeed being that training does not speed up as much as I would expect.
These are the main ones:

  • How to make sure that I have the correct features and labels tensors on all GPUs (with DistributedSampler I do not get the exact number of samples: 50001 instead of 50000); a possible fix is sketched right after this list
  • How to then correctly go through all the batches for the gradient accumulation (same issue as above; I am also not sure this is the proper way to do gradient accumulation in a distributed setting, see the no_sync sketch at the end of this reply)
  • My distributed version of the code is not really more efficient (full script below)
    • With a single GPU: 128 s for 2 epochs, 315 s for 5 epochs
    • With 3 GPUs: 101 s for 2 epochs, 235 s for 5 epochs
    • Which is roughly a 1.3x speed-up; is that expected?
  • Is there anything special to know about using mixed precision in a distributed setting?
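
One thing I considered for the first point (only a sketch, IndexedDataset and gather_full are just names I made up, and the features/labels/indices passed in are the per-rank chunks already moved to the GPU): have the dataset also return the sample index, gather the indices together with the features, and write everything back into tensors of exactly len(dataset) rows, so that the duplicates coming from the sampler's padding simply overwrite the same row.

import torch
import torch.distributed as dist
from torch.utils.data import Dataset

class IndexedDataset(Dataset):
    """Wrap a dataset so that each sample also returns its index."""
    def __init__(self, dts):
        self.dts = dts
    def __len__(self):
        return len(self.dts)
    def __getitem__(self, i):
        x, y = self.dts[i]
        return x, y, i

def gather_full(features, labels, idxs, world_size, n_total):
    # every rank holds a chunk of identical shape (DistributedSampler pads),
    # so a plain all_gather into one buffer per rank works
    f_list = [torch.empty_like(features) for _ in range(world_size)]
    l_list = [torch.empty_like(labels) for _ in range(world_size)]
    i_list = [torch.empty_like(idxs) for _ in range(world_size)]
    dist.all_gather(f_list, features)
    dist.all_gather(l_list, labels)
    dist.all_gather(i_list, idxs)
    f, l, i = torch.cat(f_list), torch.cat(l_list), torch.cat(i_list)
    # scatter back by dataset index: the padded duplicates overwrite the
    # same row, so the result has exactly n_total entries in dataset order
    full_features = torch.empty(n_total, f.size(1), dtype=f.dtype, device=f.device)
    full_labels = torch.empty(n_total, dtype=l.dtype, device=l.device)
    full_features[i] = f
    full_labels[i] = l
    return full_features, full_labels

The full distributed script I currently have is below:
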
import os
import logging
from time import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, CenterCrop
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel


class L2Norm(nn.Module):

    def forward(self, X):
        return F.normalize(X, dim=-1)


def criterion(di, lb, features, labels):
    scores = torch.mm(di, features.t())
    gt = lb.view(-1, 1) == labels.unsqueeze(0)

    loss = F.relu(-scores[gt]).mean() + F.relu(scores[~gt]).mean()
    return loss


def main_worker(rank, world_size, dts_train):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.distributed.barrier()

    net = resnet18(pretrained=True)
    net.fc = L2Norm()

    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net.to(rank, non_blocking=True)
    net = DistributedDataParallel(net, device_ids=[rank], output_device=rank)

    opt = torch.optim.SGD(net.parameters(), 0.1)
    scaler = torch.cuda.amp.GradScaler()

    for e in range(2):
        features_sampler = DistributedSampler(dts_train)
        features_loader = DataLoader(
            dts_train,
            batch_size=128,
            shuffle=False,
            drop_last=False,
            pin_memory=True,
            num_workers=10,
            sampler=features_sampler,
        )
        features_sampler.set_epoch(e)

        features = []
        labels = []
        # We first compute the features and labels for the whole dataset
        for (x, y) in features_loader:
            with torch.cuda.amp.autocast():
                with torch.no_grad():
                    feat = net(x.to(rank, non_blocking=True))
                features.append(feat)
                labels.append(y)
        features = torch.cat(features)
        labels = torch.cat(labels).to(rank, non_blocking=True)
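        # all_gather needs one same-shaped buffer per rank; DistributedSampler
        # pads each rank to ceil(len(dataset) / world_size) samples, which is
        # why the gathered total is 50001 instead of 50000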
        new_features = [torch.empty_like(features) for i in range(world_size)]
        new_labels = [torch.empty_like(labels) for i in range(world_size)]
        torch.distributed.barrier()

        dist.all_gather(new_features, features)
        dist.all_gather(new_labels, labels)
        new_features = torch.cat(new_features).to(rank, non_blocking=True)
        new_labels = torch.cat(new_labels).to(rank, non_blocking=True)

        del features, labels

        accumulation_sampler = DistributedSampler(dts_train)
        accumulation_loader = DataLoader(
            dts_train,
            batch_size=128,
            shuffle=False,
            drop_last=False,
            pin_memory=True,
            num_workers=10,
            sampler=accumulation_sampler,
        )
        accumulation_sampler.set_epoch(e)
        for (x, y) in accumulation_loader:
            with torch.cuda.amp.autocast():
                di = net(x.to(rank, non_blocking=True))
                lb = y.to(rank, non_blocking=True)
                loss = criterion(di, lb, new_features, new_labels) / len(new_features)

            # gradient accumulation for the entire dataset
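            # note: by default DDP all-reduces the gradients on every
            # backward() call here; see the no_sync() sketch after this script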
            scaler.scale(loss).backward()

        torch.distributed.barrier()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()  # reset the accumulated gradients before the next epoch

    print("finished")
    dist.destroy_process_group()


if __name__ == '__main__':

    logging.basicConfig(
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            level=logging.INFO,
        )

    start = time()
    transform = Compose(
        (
         Resize((256, 256)),
         CenterCrop(224),
         ToTensor(),
         Normalize(
             mean=[0.485, 0.456, 0.406],
             std=[0.229, 0.224, 0.225])
        )
    )

    dts_train = CIFAR10("/users/r/ramzie/datasets/", train=True, transform=transform, download=False)

    def get_loader(dts, sampler=None):
        return DataLoader(
            dts,
            batch_size=128,
            shuffle=False,
            drop_last=False,
            pin_memory=True,
            num_workers=10,
            sampler=sampler,
        )

    mp.spawn(
        main_worker,
        nprocs=3,
        args=(3, dts_train),
    )
    end = time()

    print(f"took: {end-start}")

Thanks!