I’ve been banging my head against the wall for the past four days trying to figure out why the following code OOMs both on Colab and on an A100 instance (40 GB GPU). My dataset is 5.5 GB total and should comfortably fit into memory; I batch it on the CPU and load it to the GPU via PyTorch Lightning, which abstracts away the device management. I’ve highlighted where I think the issue is with [ISSUE??] in the code below.
For context, counting byte size * number of elements puts a batch of 1024 elements at roughly 12 MB. The model I’m using is about 500 MB, yet I cannot get anything above a batch size of 16 to fit, which results in roughly 2 hours per epoch even on an A100 (about 34 days if I run this locally).
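For reference, the 12 MB figure comes from the usual bytes-per-element * number-of-elements count, roughly like this (anchor_batch is a placeholder name for a 1024-row slice of one of the pre-tokenized token-ID tensors):

# Per-batch memory estimate: bytes per element * number of elements.
batch_bytes = anchor_batch.element_size() * anchor_batch.nelement()
print(f"Approx. batch memory: {batch_bytes / 1024**2:.2f} MB")

Lightning’s model summary reports: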
  | Name       | Type                  | Params
-----------------------------------------------------
0 | base_model | CodeT5pEmbeddingModel | 109 M
1 | loss_fn    | TripletMarginLoss     | 0
-----------------------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
439.225   Total estimated model params size (MB)
Given all of these data points I’m pretty confident something is wrong with my code, but I’m not sure what else to check. Things seem to break down as soon as the forward pass finishes and the gradient computation starts.
Memory allocated before forward pass: 419.25244140625 MB
Memory allocated after loading data: 419.25244140625 MB
Memory allocated after forward pass: 608.919921875 MB
# Fails as soon as gradient update begins post-forward
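The readings above were produced with something like this small helper (the name is arbitrary; torch.cuda.memory_allocated() reports the bytes currently occupied by tensors on the GPU):

def log_cuda_memory(tag):
    print(f"Memory allocated {tag}: {torch.cuda.memory_allocated() / 1024**2} MB")

The full DataModule and LightningModule follow.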
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule
from pytorch_lightning.callbacks import DeviceStatsMonitor
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel


class TripletDataModule(LightningDataModule):
    def __init__(
        self,
        anchor_tensor,
        positive_tensor,
        negative_tensor,
        batch_size=32,
        num_workers=4,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset = TensorDataset(anchor_tensor, positive_tensor, negative_tensor)

    def setup(self, stage=None):
        total_len = len(self.dataset)
        train_len = int(0.8 * total_len)
        val_len = int(0.1 * total_len)
        test_len = total_len - train_len - val_len
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            self.dataset, [train_len, val_len, test_len]
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=self.num_workers,
        )
class TripletModel(pl.LightningModule):
    def __init__(self, model_name):
        super().__init__()
        self.base_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.loss_fn = nn.TripletMarginLoss(margin=1.0, p=2)

    def forward(self, batch):
        """Forward pass of the triplet model."""
        anchor, positive, negative = batch
        anchor_embedding = self.base_model(anchor)
        positive_embedding = self.base_model(positive)
        negative_embedding = self.base_model(negative)
        return anchor_embedding, positive_embedding, negative_embedding

    def training_step(self, batch, batch_idx):
        """Training step for the triplet model."""
        anchor, positive, negative = batch
        # [ISSUE??] This only runs without OOMing if I wrap it in torch.no_grad()
        anchor_embedding, positive_embedding, negative_embedding = self.forward(batch)
        loss = self.loss_fn(anchor_embedding, positive_embedding, negative_embedding)
        self.log("loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Validation step for the triplet model."""
        with torch.no_grad():
            anchor_embedding, positive_embedding, negative_embedding = self.forward(
                batch
            )
            loss = self.loss_fn(
                anchor_embedding, positive_embedding, negative_embedding
            )
        self.log("val_loss", loss)
        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        """Test step for the triplet model."""
        anchor_embedding, positive_embedding, negative_embedding = self.forward(batch)
        loss = self.loss_fn(anchor_embedding, positive_embedding, negative_embedding)
        self.log("test_loss", loss)
        return {"test_loss": loss}

    def configure_optimizers(self):
        # optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)
        optimizer = torch.optim.SGD(
            self.parameters(), lr=0.01, momentum=0.9, nesterov=True
        )
        return optimizer

    def encode_function(self, input):
        """Encode a function string into a single embedding vector."""
        return self.base_model(input)
For additional context, I’ve added a number of “optimizations” via Lightning which did help, but they fundamentally don’t cut training time down as much as I’d expect for data this simple. Note: I’ve done all the tokenization offline to save time and remove as many moving pieces as possible.
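The offline tokenization step looks roughly like the following (a minimal sketch; the checkpoint name and max_length are assumptions, but the point is that everything is padded/truncated to a fixed length so the token IDs stack into the single anchor/positive/negative tensors used above):

from transformers import AutoTokenizer

# Sketch of the offline tokenization (checkpoint and max_length are placeholders).
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5p-110m-embedding")

def tokenize_offline(code_strings, max_length=512):
    # Pad/truncate to a fixed length so the IDs stack into one (N, max_length) tensor.
    encoded = tokenizer(
        code_strings,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return encoded["input_ids"]

# anchor_tensor = tokenize_offline(anchor_strings), and likewise for positives/negatives.

The saved tensors are then loaded back at training time: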
# Load some tensors
# ...
data_module = TripletDataModule(
    anchor_tensor, positive_tensor, negative_tensor, batch_size=64, num_workers=12
)
data_module.setup()
# Run training
model = TripletModel(model_name)
trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    precision="16-mixed",
    log_every_n_steps=15,
    accumulate_grad_batches=10,
    gradient_clip_val=0.5,
    max_epochs=1000,
    max_steps=10,
    callbacks=[DeviceStatsMonitor(cpu_stats=True), checkpoint_callback],
    # profiler=profiler,
    default_root_dir=".../triplets/checkpoints/",
)
trainer.fit(model, datamodule=data_module)
Any ideas on how to get this working? I’d expect an A100 to support batch sizes of 256+ given the data and model size, but I’m not sure what is happening during the gradient computation, or whether something else I did is creating a lot of hidden overhead.