Out-of-Memory Error in Multi-GPU Distributed Training with PyTorch and Hugging Face Trainer

Problem Description:

I’m encountering an out-of-memory (OOM) error while running a multi-GPU distributed PyTorch training script launched with torch.distributed.launch. The model is too large for a single GPU, so I rely on multi-GPU training to distribute the load. However, when computing gradients with respect to specific inputs (e.g., the aux tensor), memory usage increases significantly, leading to OOM.

To isolate the issue, I created a reproducible example using a smaller model (BERT for sequence classification). While this toy example fits in memory even in distributed mode, my actual model does not, despite a seemingly similar setup. Moreover, if I disable gradient computation (aux.requires_grad_(False)), the real model works fine.

I suspect that either:

  1. Gradient computation is significantly increasing the memory footprint in distributed mode.
  2. There might be inefficiencies in how gradients are handled across GPUs.
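
To try to tell these two apart, I log allocated and peak CUDA memory around the forward and backward passes. This is a small debugging helper of my own (log_cuda_memory is not part of the reproducible example below), shown here as a sketch:

python

import torch


def log_cuda_memory(tag, device=None):
    """Print currently allocated and peak allocated CUDA memory in MiB."""
    allocated = torch.cuda.memory_allocated(device) / 2**20
    peak = torch.cuda.max_memory_allocated(device) / 2**20
    print(f"[{tag}] allocated: {allocated:.1f} MiB, peak: {peak:.1f} MiB")


# Intended usage inside the saliency loop:
#   torch.cuda.reset_peak_memory_stats(device)
#   log_cuda_memory("before forward", device)
#   ...forward pass...
#   log_cuda_memory("after forward", device)
#   ...loss.backward()...
#   log_cuda_memory("after backward", device)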

Below is the reproducible example code and the steps to reproduce the issue.


Code to Reproduce:

python

import os

import torch
from torch import nn
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class ExampleDataset(Dataset):
    def __init__(self, tokenizer, size=10000, max_length=128):
        self.tokenizer = tokenizer
        self.texts = [f"This is example text {i}" for i in range(size)]
        self.labels = torch.randint(0, 2, (size,))  # Binary classification
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        tokenized = self.tokenizer(
            self.texts[idx],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = tokenized["input_ids"].squeeze(0)
        aux = torch.rand((5,)) * 100  # Random auxiliary input

        return {
            "aux": aux,
            "input_ids": input_ids,
            "attention_mask": tokenized["attention_mask"].squeeze(0),
            "labels": self.labels[idx],
        }


class CustomModel(nn.Module):
    def __init__(self, base_model_name="bert-base-uncased", num_labels=2):
        super(CustomModel, self).__init__()
        self.bert = BertForSequenceClassification.from_pretrained(base_model_name, num_labels=num_labels)
        self.embedding_dim = self.bert.config.hidden_size
        self.aux_linear = nn.Linear(5, self.embedding_dim)

    def forward(self, input_ids, attention_mask, aux, labels=None):
        aux_embedded = self.aux_linear(aux)  # Embed auxiliary input
        input_embeddings = self.bert.bert.embeddings(input_ids)
        input_embeddings[:, 0, :] += aux_embedded  # Modify embeddings

        outputs = self.bert(
            inputs_embeds=input_embeddings,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs


def compute_saliency_maps(trainer, loader, device, repeat_factor=1000):
    model = trainer.model
    model.eval()

    for _ in range(repeat_factor):
        for batch in tqdm(loader, desc="Computing Saliency Maps"):
            batch = {key: value.to(device) if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
            batch["aux"].requires_grad_(True)  # Enable gradients for aux

            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                aux=batch["aux"],
                labels=batch["labels"]
            )
            loss = outputs.loss
            loss.backward()  # Compute gradients
            grads = batch["aux"].grad
            print(f"Gradient norm: {grads.norm().item()}")


if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = CustomModel()

    dataset = ExampleDataset(tokenizer, size=200)
    loader = DataLoader(dataset, batch_size=1000)

    training_args = TrainingArguments(
        output_dir="./results",
        per_device_eval_batch_size=8,
        do_train=False,
        do_eval=True,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        eval_dataset=dataset,
    )

    # Bind each process to its own GPU when launched distributed; assumes the
    # launcher exports LOCAL_RANK (falls back to cuda:0 / CPU otherwise).
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    model.to(device)
    compute_saliency_maps(trainer, loader, device, repeat_factor=3)
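
For reference, my mental model of what multi-GPU mode adds on top of this script is standard DistributedDataParallel (DDP) data parallelism; as far as I understand, the Hugging Face Trainer does this wrapping itself during training. A minimal hand-rolled equivalent, building on the CustomModel class above, would look roughly like this (a sketch of my assumptions, not code from my real script; it assumes the launcher exports LOCAL_RANK):

python

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Sketch: bind each process to its own GPU and wrap the model in DDP.
# Assumes LOCAL_RANK, MASTER_ADDR, etc. are provided by the launcher
# (torchrun / torch.distributed.launch).
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")

model = CustomModel().to(local_rank)
ddp_model = DDP(model, device_ids=[local_rank])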

Steps to Reproduce:

  1. Run the script in single-GPU mode:

bash

python debug_playground.py

This works as expected, but for my real model, it causes an OOM error.

  2. Run in multi-GPU mode with torch.distributed.launch:

bash

python -m torch.distributed.launch --nproc_per_node=8 debug_playground.py

  3. Observe the GPU memory usage (nvidia-smi).
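
In addition to nvidia-smi, I also print PyTorch's own allocator statistics from inside the loop, which makes it easier to attribute usage to a specific process. A small sketch (RANK is assumed to be exported by the launcher and defaults to 0 for the single-GPU run):

python

import os

import torch

# Per-process allocator snapshot; I call this after loss.backward().
rank = int(os.environ.get("RANK", 0))
print(f"[rank {rank}]")
print(torch.cuda.memory_summary(abbreviated=True))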

Observations:

  • Single-GPU mode: OOM occurs with my real model, as expected for large parameter counts.
  • Multi-GPU mode: The toy example works, but my real model still OOMs. Gradient computation for aux seems to cause a drastic increase in memory usage.
  • Without gradients for aux: The real model runs without issues in multi-GPU mode.
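
One variant I have not yet verified on the real model: replacing the full loss.backward() in compute_saliency_maps with torch.autograd.grad restricted to aux, which leaves the .grad buffers of the model parameters unallocated. If that removes the OOM, it would point at hypothesis 1 above. Sketch of the changed lines only:

python

# Inside the loop in compute_saliency_maps: compute only d(loss)/d(aux)
# instead of a full loss.backward(), so no .grad buffers are allocated
# for the model parameters. Not yet verified on the real model.
loss = outputs.loss
(aux_grad,) = torch.autograd.grad(loss, batch["aux"])
print(f"Gradient norm: {aux_grad.norm().item()}")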

Questions:

  1. Why does enabling gradients for the aux tensor cause such a significant memory increase in distributed mode?
  2. Is there still a way to distribute the gradient computation across GPUs in this setup?

Any insights or suggestions would be greatly appreciated!