Problem Description:
I’m encountering an out-of-memory (OOM) error while running a PyTorch distributed training script on multiple GPUs with `torch.distributed.launch`. The model is too large for a single GPU, so I use multi-GPU training to distribute the load. However, when computing gradients with respect to specific inputs (e.g., the `aux` tensor), memory usage increases significantly, leading to OOM.
To isolate the issue, I created a reproducible example using a smaller model (BERT for sequence classification). While this toy example fits in memory even in distributed mode, my actual model doesn’t, despite a seemingly similar setup. Moreover, if I disable gradient computation (`aux.requires_grad_(False)`), the real model works fine.
I suspect that either:
- Gradient computation is significantly increasing the memory footprint in distributed mode (a minimal measurement sketch follows this list).
- There might be inefficiencies in how gradients are handled across GPUs.
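To check the first suspicion, here is a minimal, hedged sketch (not part of the original script) of how peak memory around a single forward/backward step could be measured; `model`, `batch`, and `device` are assumed to be the same objects used in the reproducer below, and the helper name is just illustrative:

```python
import torch

# Measurement sketch (assumption: `model`, `batch`, and `device` are the same
# objects as in the reproducer below, with all tensors already on `device`).
def measure_backward_peak(model, batch, device):
    torch.cuda.reset_peak_memory_stats(device)   # start peak tracking from the current allocation
    batch["aux"].requires_grad_(True)            # track gradients for aux, as in the saliency loop
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        aux=batch["aux"],
        labels=batch["labels"],
    )
    forward_peak = torch.cuda.max_memory_allocated(device)
    outputs.loss.backward()
    backward_peak = torch.cuda.max_memory_allocated(device)
    print(f"Peak after forward: {forward_peak / 2**20:.1f} MiB, "
          f"after backward: {backward_peak / 2**20:.1f} MiB")
```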
Below is the reproducible example code and the steps to reproduce the issue.
Code to Reproduce:
```python
import torch
from torch import nn
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class ExampleDataset(Dataset):
    def __init__(self, tokenizer, size=10000, max_length=128):
        self.tokenizer = tokenizer
        self.texts = [f"This is example text {i}" for i in range(size)]
        self.labels = torch.randint(0, 2, (size,))  # Binary classification
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        tokenized = self.tokenizer(
            self.texts[idx],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = tokenized["input_ids"].squeeze(0)
        aux = torch.rand((5,)) * 100  # Random auxiliary input
        return {
            "aux": aux,
            "input_ids": input_ids,
            "attention_mask": tokenized["attention_mask"].squeeze(0),
            "labels": self.labels[idx],
        }


class CustomModel(nn.Module):
    def __init__(self, base_model_name="bert-base-uncased", num_labels=2):
        super(CustomModel, self).__init__()
        self.bert = BertForSequenceClassification.from_pretrained(base_model_name, num_labels=num_labels)
        self.embedding_dim = self.bert.config.hidden_size
        self.aux_linear = nn.Linear(5, self.embedding_dim)

    def forward(self, input_ids, attention_mask, aux, labels=None):
        aux_embedded = self.aux_linear(aux)  # Embed auxiliary input
        input_embeddings = self.bert.bert.embeddings(input_ids)
        input_embeddings[:, 0, :] += aux_embedded  # Modify embeddings
        outputs = self.bert(
            inputs_embeds=input_embeddings,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs


def compute_saliency_maps(trainer, loader, device, repeat_factor=1000):
    model = trainer.model
    model.eval()
    for _ in range(repeat_factor):
        for batch in tqdm(loader, desc="Computing Saliency Maps"):
            batch = {key: value.to(device) if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
            batch["aux"].requires_grad_(True)  # Enable gradients for aux
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                aux=batch["aux"],
                labels=batch["labels"],
            )
            loss = outputs.loss
            loss.backward()  # Compute gradients
            grads = batch["aux"].grad
            print(f"Gradient norm: {grads.norm().item()}")


if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = CustomModel()
    dataset = ExampleDataset(tokenizer, size=200)
    loader = DataLoader(dataset, batch_size=1000)
    training_args = TrainingArguments(
        output_dir="./results",
        per_device_eval_batch_size=8,
        do_train=False,
        do_eval=True,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        eval_dataset=dataset,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    compute_saliency_maps(trainer, loader, device, repeat_factor=3)
```
Steps to Reproduce:
1. Run the script in single-GPU mode:

   ```bash
   python debug_playground.py
   ```

   This works as expected, but for my real model it causes an OOM error.

2. Run in multi-GPU mode with `torch.distributed.launch`:

   ```bash
   python -m torch.distributed.launch --nproc_per_node=8 debug_playground.py
   ```

3. Observe the GPU memory usage (`nvidia-smi`); a small programmatic alternative is sketched below.
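As a complement to `nvidia-smi`, a short snippet like the following could be dropped into the script to print per-rank memory from inside the distributed run. This is a hedged sketch, not part of the original reproducer; it assumes CUDA is available and one GPU per process:

```python
import torch
import torch.distributed as dist

# Illustrative only: report this process's CUDA memory usage
# (assumes CUDA is available and one GPU per process).
rank = dist.get_rank() if dist.is_initialized() else 0
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
allocated_mib = torch.cuda.memory_allocated(device) / 2**20
reserved_mib = torch.cuda.memory_reserved(device) / 2**20
print(f"[rank {rank}] allocated: {allocated_mib:.1f} MiB, reserved: {reserved_mib:.1f} MiB")
```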
Observations:
- Single-GPU mode: OOM occurs with my real model, as expected for large parameter counts.
- Multi-GPU mode: the toy example works, but my real model still OOMs. Gradient computation for `aux` seems to cause a drastic increase in memory usage.
- Without gradients for `aux`: the real model runs without issues in multi-GPU mode.
Questions:
- Why does enabling gradients for the `aux` tensor cause such a significant memory increase in distributed mode?
- Is there a way to still distribute the gradient computation in this setup?
Any insights or suggestions would be greatly appreciated!