Conditional gradient update in "DistributedDataParallel"

Hi all,
I want to update the weights only if the loss value is less than some threshold. This works fine in the single-GPU case, but training hangs (or sometimes throws a GPU out-of-memory error) when using "DistributedDataParallel" on a single node.
Here is an example that reproduces the problem. Could you folks help me figure it out?

import os
from datetime import datetime
import argparse
import torch.multiprocessing as mp
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.distributed as dist


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-n",
        "--nodes",
        default=1,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 4)",
    )
    parser.add_argument(
        "-g", "--gpus", default=1, type=int, help="number of gpus per node"
    )
    parser.add_argument(
        "-nr", "--nr", default=0, type=int, help="ranking within the nodes"
    )
    parser.add_argument(
        "--epochs",
        default=2,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    args = parser.parse_args()
    args.world_size = args.gpus * args.nodes
    os.environ["MASTER_ADDR"] = "tcp://127.0.0.1"
    os.environ["MASTER_PORT"] = "23456"
    mp.spawn(train, nprocs=args.gpus, args=(args,))


class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out


def train(gpu, args):
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:23456",
        world_size=args.world_size,
        rank=rank,
    )
    torch.manual_seed(0)
    model = ConvNet()
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    batch_size = 100
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    # Wrap the model
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    # Data loading code
    train_dataset = torchvision.datasets.MNIST(
        root="./data",
        train=True,
        transform=transforms.ToTensor(),
        download=True,
    )
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        sampler=train_sampler,
    )

    start = datetime.now()
    total_step = len(train_loader)
    for epoch in range(args.epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            
            # Training gets halted here
            if loss.item() > 1.8:
                loss.backward()
            else:
                print("skipping batch:", loss.item())
            optimizer.step()
            print(
                "GPU: {}, Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format(
                    gpu, epoch + 1, args.epochs, i + 1, total_step, loss.item()
                )
            )
    if gpu == 0:
        print("Training complete in: " + str(datetime.now() - start))


if __name__ == "__main__":
    main()

@smth can you please help solve this issue?

if loss.item() > 1.8: 
    loss.backward()
else:
    print("skipping batch:", loss.item())

The above is likely the cause of the problem. When using DistributedDataParallel, the backward() pass triggers gradient synchronization (all_reduce) across all processes, which means all processes need to agree on the number and order of all_reduce calls. However, the code above skips the backward pass in some processes without guaranteeing that the other processes skip it as well. When that happens, the processes fall out of sync and the training hangs.

Thanks for the explanation.
But I want the following behavior: if the loss in any process exceeds some threshold, then no process should do the gradient update. Is this achievable with DistributedDataParallel?

When using DistributedDataParallel (DDP), the loss is a local variable; DDP does not communicate loss values across processes. To make this work, you can do the following on each process (see the sketch after this list):

  1. run the forward pass on the DDP model to compute the local loss.
  2. create a tensor flag that records whether the local loss exceeds the threshold.
  3. use all_reduce or all_gather to collectively communicate this flag to all processes.
  4. after step 3, every process has the same view of whether to launch backward + step, so they all take the same branch and avoid the desync problem.
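Putting steps 1-4 together, here is a minimal sketch of what the training-loop body could look like. It assumes the surrounding variables from the example above (model, images, labels, criterion, optimizer, and torch.distributed imported as dist); the threshold value of 1.8 and the skip_flag name are just illustrative, and the skip condition follows the follow-up question: skip the update on every rank if any rank's loss exceeds the threshold.

            threshold = 1.8  # illustrative value, matching the example above

            # 1. forward pass on the DDP model to compute the local loss
            outputs = model(images)
            loss = criterion(outputs, labels)

            # 2. a tensor flag recording whether this rank's loss exceeds the threshold
            skip_flag = torch.tensor(
                [1.0 if loss.item() > threshold else 0.0], device=images.device
            )

            # 3. all_reduce with MAX: the flag becomes 1 on every rank if ANY rank set it
            dist.all_reduce(skip_flag, op=dist.ReduceOp.MAX)

            # 4. every rank now takes the same branch, so DDP's gradient all_reduce
            #    calls stay in sync across processes
            optimizer.zero_grad()
            if skip_flag.item() == 0:
                loss.backward()
                optimizer.step()
            else:
                print("skipping batch on all ranks, local loss:", loss.item())

Because every rank reduces the same flag before deciding, either all ranks run backward (and DDP's gradient all_reduce matches up) or all ranks skip it, which avoids the hang.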