Hi everyone,
I am currently working on training a model comprising two components: an encoder and a classifier. Given that my data is relational, it’s imperative that the classifier utilizes embeddings computed by the encoder from all samples in the dataset. However, due to the substantial size of the encoder network, I’ve opted to utilize Distributed Data Parallelism (DDP) to manage the dataset effectively. This involves partitioning the dataset into batches, distributing each batch across different GPUs for encoding, and then gathering and concatenating the encodings on a single GPU (e.g., the root GPU) before passing them to the classifier (which will be on the same device). Although I’ve implemented this structure, I’m uncertain if it’s the most effective approach. Additionally, I’ve observed that the optimizer.step()
function appears to be taking forever. Any assistance or insights on optimizing this implementation would be greatly appreciated.
import torch
import torch.nn as nn
import torch.distributed as dist
def gather_concat(input, rank, dim=0):
group = dist.new_group(list(range(dist.get_world_size())))
if rank == 0:
# create an empty list we will use to hold the gathered values
tensor_list = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
dist.gather(input, gather_list=tensor_list, dst=0, group=group)
return torch.cat(tensor_list, dim=0)
else:
dist.gather(input, gather_list=[], dst=0, group=group)
return None
class Model(nn.Module):
def __init__(self, encoder, classifier):
super().__init__()
self.encoder = encoder
self.classifier = classifier
def forward(self, dataloader, adjacency, local_rank):
emb_list, index_list = [], []
for batch in dataloader:
feat = batch['features'].to(local_rank) #feat shape: [B, N, D_feat]
index = batch['index'].to(local_rank) #the index of the samples in the dataset
emb = self.encoder(feat) #emd shape: [B, N, D_emb]
emb_list.append(emb)
index_list.append(index)
embbedings = torch.cat(emb_list, dim=0)
indices = torch.cat(index_list, dim=0)
embbedings = gather_concat(embbedings, local_rank)
indices = gather_concat(indices, local_rank)
if embbedings is None:
return None
else:
_, indices = torch.sort(indices)
embeddings = embeddings[indices] #sort the embeddings based on their indices to match the ordering with adjacency matrix
embbedings = torch.flatten(embbedings, start_dim=1) #shape: [B, N * D_emb]
logits = self.classifier(embbedings, adjacency.to(local_rank)) #shape: [B, num_classes]
return logits
def run_epoch(epoch):
b_sz = len(next(iter(dataloader))['features'])
dataloader.sampler.set_epoch(epoch)
optimizer.zero_grad()
logits = model(dataloader, adjacency, local_rank)
if logits is not None:
loss = F.cross_entropy(
logits[train_mask],
labels[train_mask].to(local_rank)
)
loss.backward()
optimizer.step()
Thanks a lot