Triplet Loss OOM CUDA (A100 + Small Model)

I've been banging my head against the wall for the past 4 days trying to figure out why the following code OOMs on both Colab and an A100 instance (40 GB GPU). My dataset is 5.5 GB total and should comfortably fit into memory; I batch it on the CPU and move it to the GPU with PyTorch Lightning, which abstracts away the device management. I've highlighted where I think the issue is with [ISSUE??] in the code below.

For context, counting byte size × number of elements puts a batch of 1024 triplets at about 12 MB. The model I am using seems to be ~500 MB, yet I cannot get anything above a batch size of 16 to run, which works out to roughly 2 hours per epoch even on an A100 (about 34 days if I run this locally).
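
For reference, the 12 MB figure comes from a back-of-the-envelope like the one below (the sequence length of 512 and int64 token IDs are illustrative assumptions about the tokenized tensors):

import torch

# raw input size per batch: anchor + positive + negative token IDs
ids = torch.zeros(1024, 512, dtype=torch.int64)           # e.g. 1024 sequences of 512 int64 tokens
per_batch = 3 * ids.nelement() * ids.element_size()
print(f"{per_batch / 1024**2:.0f} MB per batch of 1024 triplets")   # -> 12 MB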

  | Name       | Type                  | Params
-----------------------------------------------------
0 | base_model | CodeT5pEmbeddingModel | 109 M 
1 | loss_fn    | TripletMarginLoss     | 0     
-----------------------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
439.225   Total estimated model params size (MB)

I'm pretty confident something is wrong with my code given all of these data points, but I am not sure what else to check. It seems like things break down once I run the forward pass and compute gradients.

Memory allocated before forward pass: 419.25244140625 MB
Memory allocated after loading data: 419.25244140625 MB
Memory allocated after forward pass: 608.919921875 MB
# Fails as soon as gradient update begins post-forward
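
As a sanity check on those numbers, the ~419 MB allocated before the forward pass looks like it is just the fp32 parameters (the exact parameter count below is back-calculated from the summary above, so treat it as an estimate):

# ~109.8M parameters * 4 bytes (fp32)
print(109_800_000 * 4 / 1024**2)   # ≈ 419 MiB, matching "before forward pass"
print(109_800_000 * 4 / 1e6)       # ≈ 439 MB, matching Lightning's size estimate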

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule
from pytorch_lightning.callbacks import DeviceStatsMonitor
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel


class TripletDataModule(LightningDataModule):
    def __init__(
        self,
        anchor_tensor,
        positive_tensor,
        negative_tensor,
        batch_size=32,
        num_workers=4,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset = TensorDataset(anchor_tensor, positive_tensor, negative_tensor)

    def setup(self, stage=None):
        total_len = len(self.dataset)
        train_len = int(0.8 * total_len)
        val_len = int(0.1 * total_len)
        test_len = total_len - train_len - val_len

        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            self.dataset, [train_len, val_len, test_len]
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=self.num_workers,
        )


class TripletModel(pl.LightningModule):
    def __init__(self, model_name):
        super(TripletModel, self).__init__()
        self.base_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.loss_fn = nn.TripletMarginLoss(margin=1.0, p=2)

    def forward(self, batch):
        """Forward pass of the triplet model."""
        anchor, positive, negative = batch

        anchor_embedding = self.base_model(anchor)
        positive_embedding = self.base_model(positive)
        negative_embedding = self.base_model(negative)
        return anchor_embedding, positive_embedding, negative_embedding

    def training_step(self, batch, batch_idx):
        """Training step for the triplet model."""
        anchor, positive, negative = batch

        # [ISSUE??] THIS ONLY RUNS with torch.no_grad():
        anchor_embedding, positive_embedding, negative_embedding = self.forward(batch)
        loss = self.loss_fn(anchor_embedding, positive_embedding, negative_embedding)
        self.log("loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Validation step for the triplet model."""
        with torch.no_grad():
            anchor_embedding, positive_embedding, negative_embedding = self.forward(
                batch
            )
            loss = self.loss_fn(
                anchor_embedding, positive_embedding, negative_embedding
            )
            self.log("val_loss", loss)
        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        """Test step for the triplet model."""
        anchor_embedding, positive_embedding, negative_embedding = self.forward(batch)
        loss = self.loss_fn(anchor_embedding, positive_embedding, negative_embedding)
        self.log("test_loss", loss)
        return {"test_loss": loss}

    def configure_optimizers(self):
        # optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)
        optimizer = torch.optim.SGD(
            self.parameters(), lr=0.01, momentum=0.9, nesterov=True
        )
        return optimizer

    def encode_function(self, input):
        """Encode a function string into a single embedding vector."""
        return self.base_model(input)

For additional context, I've added a lot of "optimizations" via Lightning, which did help, but they fundamentally do not cut training time down as much as I'd expect for data this simple. Note: I've done all the tokenization offline to save time and remove as many moving pieces as possible.

# Load some tensors
# ... 
data_module = TripletDataModule(
    anchor_tensor, positive_tensor, negative_tensor, batch_size=64, num_workers=12
)
data_module.setup()

# Run training
model = TripletModel(model_name)
trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    precision="16-mixed",
    log_every_n_steps=15,
    accumulate_grad_batches=10,
    gradient_clip_val=0.5,
    max_epochs=1000,
    max_steps=10,
    callbacks=[DeviceStatsMonitor(cpu_stats=True), checkpoint_callback],
    # profiler=profiler,
    default_root_dir=".../triplets/checkpoints/"
)

trainer.fit(model, datamodule=data_module)

Any ideas on how to get this working? I'd expect an A100 to support batch sizes of 256+ given the data and model size, but I'm not sure what is happening during the gradient computation, or whether something else I did is creating a lot of hidden overhead.

Besides the input data and the model's parameters and buffers, the intermediate forward activations can use a lot of memory depending on the model architecture. I don't know how you've measured the memory usage, but this post explains it in more detail for a ResNet.
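
As a toy illustration (not your model), you can watch the stored activations show up after the forward pass and the peak grow during the backward pass:

import torch
import torch.nn as nn

model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)]).cuda()
x = torch.randn(256, 4096, device="cuda")

torch.cuda.reset_peak_memory_stats()
print(f"params:        {torch.cuda.memory_allocated() / 1024**2:.0f} MB")
out = model(x).sum()
print(f"after forward: {torch.cuda.memory_allocated() / 1024**2:.0f} MB")     # + stored activations
out.backward()
print(f"peak:          {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB")  # + gradients and temps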

Thanks! As a follow-up, if I wanted to get a larger batch size or decrease training time, I could watch this memory usage while applying things like:

  • Quantization
  • Mixed precision
  • Gradient accumulation

And I should* see the memory usage decrease, right? It seems like when I applied many of these optimizations, training time increased, though perhaps convergence was faster… it's hard to tell whether they made any notable improvement. I basically tried everything mentioned in these guides:

I’m not too familiar with Quantization, but mixed-precision training could reduce the memory usage and speed up the training. Reducing the batch size and applying gradient accumulation would trade compute for memory.
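
Roughly, those two pieces correspond to something like this in plain PyTorch (just a sketch: model, loss_fn, optimizer, and train_loader stand in for your objects, and batches are assumed to already be on the GPU):

import torch

# model, loss_fn, optimizer, train_loader are assumed to be defined and on the GPU
scaler = torch.cuda.amp.GradScaler()   # scales the loss so float16 gradients don't underflow
accum_steps = 4                        # effective batch size = micro-batch size * accum_steps

optimizer.zero_grad()
for i, (anchor, positive, negative) in enumerate(train_loader):
    with torch.autocast(device_type="cuda", dtype=torch.float16):   # mixed precision
        loss = loss_fn(model(anchor), model(positive), model(negative)) / accum_steps
    scaler.scale(loss).backward()      # activations are freed after each micro-batch
    if (i + 1) % accum_steps == 0:     # gradient accumulation: step every accum_steps batches
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()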

Could I just call the training step and use torch.cuda.max_memory_allocated() to get an accurate read? It seems like loading this hit a peak of 14.30 GB, even though my card only has 8 GB (RTX 2070)? Trying again with quantization + LoRA to see if there is any change.

And I found your other code snippet for measuring model size (it reports ~418 MB):

param_size = 0
for param in model.parameters():
    param_size += param.nelement() * param.element_size()     # bytes used by parameters

buffer_size = 0
for buffer in model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()  # bytes used by buffers

size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

torch.cuda.max_memory_allocated() will return the peak allocated memory, which can be helpful to e.g. see if you could increase the batch size etc.
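
For example, after a single forward/backward pass (assuming model and an (anchor, positive, negative) batch are already on the GPU):

torch.cuda.reset_peak_memory_stats()
anchor_emb, pos_emb, neg_emb = model(batch)          # batch = (anchor, positive, negative)
loss = model.loss_fn(anchor_emb, pos_emb, neg_emb)
loss.backward()
print(f"peak: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")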

I don't understand this statement, as you won't be able to allocate more memory than is available, so did max_memory_allocated() return 14 GB?

Yes, that's what it's showing after I do a single training step and move the model/batch to CUDA. I was also able to run a batch size of 32, which showed 28.X GB, and nvitop shows that the card is maxed out on resources.

How is this possible if your GPU supposedly has only 8GB?