Gather outputs from all GPUs on the master GPU and use them as input to the subsequent layers

Hi everyone,

I am currently working on training a model comprising two components: an encoder and a classifier. Since my data is relational, the classifier needs the embeddings computed by the encoder for all samples in the dataset. However, because the encoder network is large, I've opted to use Distributed Data Parallel (DDP) to spread the encoding work across GPUs. This involves partitioning the dataset into batches, distributing the batches across different GPUs for encoding, and then gathering and concatenating the encodings on a single GPU (e.g., the root GPU) before passing them to the classifier (which lives on the same device). Although I've implemented this structure, I'm uncertain whether it's the most effective approach. Additionally, I've observed that the optimizer.step() call appears to take forever. Any assistance or insights on optimizing this implementation would be greatly appreciated.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist


def gather_concat(input, rank, dim=0):
    # gather the per-rank tensors on rank 0 and concatenate them
    # (assumes every rank produces a tensor of the same shape)
    if rank == 0:
        # buffers that will hold the tensors gathered from every rank
        tensor_list = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.gather(input, gather_list=tensor_list, dst=0)
        return torch.cat(tensor_list, dim=dim)
    else:
        dist.gather(input, gather_list=None, dst=0)
        return None


class Model(nn.Module):
    def __init__(self, encoder, classifier):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier

    def forward(self, dataloader, adjacency, local_rank):
        emb_list, index_list = [], []
        for batch in dataloader:
            feat = batch['features'].to(local_rank)      # feat shape: [B, N, D_feat]
            index = batch['index'].to(local_rank)        # index of the samples in the dataset

            emb = self.encoder(feat)                     # emb shape: [B, N, D_emb]

            emb_list.append(emb)
            index_list.append(index)

        embeddings = torch.cat(emb_list, dim=0)
        indices = torch.cat(index_list, dim=0)

        # gather the per-rank embeddings and indices on rank 0
        embeddings = gather_concat(embeddings, local_rank)
        indices = gather_concat(indices, local_rank)

        if embeddings is None:
            # non-zero ranks stop here; only rank 0 runs the classifier
            return None
        else:
            # sort the embeddings by their dataset indices so their ordering
            # matches the adjacency matrix
            _, order = torch.sort(indices)
            embeddings = embeddings[order]
            embeddings = torch.flatten(embeddings, start_dim=1)              # shape: [num_samples, N * D_emb]
            logits = self.classifier(embeddings, adjacency.to(local_rank))   # shape: [num_samples, num_classes]
            return logits


def run_epoch(epoch):
    dataloader.sampler.set_epoch(epoch)      # let the DistributedSampler reshuffle per epoch
    optimizer.zero_grad()

    logits = model(dataloader, adjacency, local_rank)

    if logits is not None:
        loss = F.cross_entropy(
            logits[train_mask],
            labels[train_mask].to(local_rank)
        )
        loss.backward()
        optimizer.step()

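For completeness, my setup and launch code looks roughly like this (a simplified sketch; MyEncoder, MyClassifier, MyDataset and load_graph_data are placeholders for my actual modules/helpers, and I launch the script with torchrun):

import os
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

dist.init_process_group(backend="nccl")                  # torchrun sets LOCAL_RANK etc.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

dataset = MyDataset()                                    # placeholder for my dataset
adjacency, labels, train_mask = load_graph_data()        # placeholder for my graph data
dataloader = DataLoader(
    dataset,
    batch_size=32,
    sampler=DistributedSampler(dataset),                 # each rank sees a different shard
)

model = Model(MyEncoder(), MyClassifier()).to(local_rank)    # placeholders for my sub-modules
model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 10                                          # placeholder
for epoch in range(num_epochs):
    run_epoch(epoch)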
Thanks a lot

Hi @ptrblck! Would you mind helping me with this issue?

I’m unsure why your current approach uses DDP for the encoder only instead of the full model. Could you describe what kind of advantage you are seeing in this approach? I would assume it causes a larger communication overhead (to gather everything on the default device) as well as low GPU utilization, since only the default device performs some of the work.

Thanks a lot for your reply @ptrblck . Essentially, my classifier is a Graph Neural Network (GNN) that processes all node embeddings and the adjacency matrix at once to compute the class scores, i.e. all node embeddings must be prepared before they are fed into the GNN. However, due to the size of my encoder, encoding all nodes on a single device isn’t feasible, which is why I resorted to DDP. Using DDP for the GNN itself isn’t ideal, because I want to avoid splitting my graph into smaller segments (which would disrupt the relational information between nodes).
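To give a rough idea, the classifier has this general form (a heavily simplified single-layer GCN-style sketch using the torch/nn imports from above, not my exact architecture):

class GNNClassifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, num_classes):
        super().__init__()
        self.lin = nn.Linear(in_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, adjacency):
        # x: [num_nodes, in_dim], adjacency: [num_nodes, num_nodes]
        # every node aggregates its neighbours' embeddings, so the full
        # embedding matrix has to be available in one place
        h = torch.relu(adjacency @ self.lin(x))
        return self.out(h)                       # [num_nodes, num_classes]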

I don’t understand this issue, since DDP would replicate the model and synchronize the gradients in the backward pass. The model itself will not be sharded.
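I.e. something along these lines would keep a full replica of the model on every rank, and only the gradients would be communicated (rough sketch):

from torch.nn.parallel import DistributedDataParallel as DDP

model = Model(encoder, classifier).to(local_rank)   # full model on every rank
model = DDP(model, device_ids=[local_rank])         # gradients are all-reduced in backward()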