Hello PyTorch Community,
I am encountering a persistent CUDA out of memory (OOM) error while training a large BERT-based model for sequence classification. The issue arises specifically around the jvp() function for Jacobian-vector products in my training loop. Below is an overview of the relevant parts of my code and the specific issues I'm facing:
Note: the memory footprint is much higher than I would expect, and when I try to profile the run I also hit the OOM error.
Code Snippet:
An outline of the relevant parts of the implementation:
# Simplified snippet focusing on the problematic area
for epoch in range(num_epochs):
    for batch in train_loader:
        # Random tangent vector for each parameter
        v_params = tuple(torch.randn_like(p).to(device) for p in params)
        # Forward pass and loss computation
        loss = model(batch)
        # Jacobian-vector product calculation
        # (fwd_function, defined elsewhere, maps params to the loss for the current batch)
        _, jvp_val = torch.autograd.functional.jvp(fwd_function, (params,), (v_params,))
        # Gradient assignment: scale each tangent by the directional derivative
        for p, v in zip(params, v_params):
            p.grad = v * jvp_val
        del v_params
        optimizer.step()
        optimizer.zero_grad()
        torch.cuda.empty_cache()  # Attempt to free up unused cached memory
Issue Description:
- The training process consistently results in an OOM error when calculating jvp().
- Despite attempts to manage memory by deleting temporary variables and clearing the CUDA cache, the error persists.
- The model trains without issues when the jvp() call is omitted, suggesting a significant memory overhead associated with this operation.
Attempts to Debug:
- Memory profiling indicates high memory usage at the jvp() call (roughly measured as shown below), but I have not been able to reduce this usage effectively.
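For reference, this is roughly how I am measuring memory around the call; fwd_function, params, and v_params are the same placeholders as in the snippet above.

# Rough measurement around the jvp() call (sketch; names are placeholders from the snippet above)
torch.cuda.reset_peak_memory_stats()
before = torch.cuda.memory_allocated()

_, jvp_val = torch.autograd.functional.jvp(fwd_function, (params,), (v_params,))

after = torch.cuda.memory_allocated()
peak = torch.cuda.max_memory_allocated()
print(f"allocated: {before / 1e9:.2f} GB -> {after / 1e9:.2f} GB, peak during jvp: {peak / 1e9:.2f} GB")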
Questions:
- Has anyone experienced similar issues with jvp() in large model training scenarios?
- Are there known memory optimization strategies or alternative approaches to computing the Jacobian-vector product that might reduce memory consumption? (A sketch of the kind of alternative I have been considering is after this list.)
- Any recommendations for profiling tools or techniques that could provide more detailed insights into memory allocation and deallocation within PyTorch?
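For context, the alternative I have been looking at is the forward-mode API in torch.func; the docs note that torch.autograd.functional.jvp is computed via the double-backwards trick, so I suspect that is where the extra memory comes from. Below is a rough, untested sketch of what I mean, reusing model and batch from the snippet above; I have not verified that it actually lowers peak memory for a model this size.

# Sketch: forward-mode JVP via torch.func instead of torch.autograd.functional.jvp
# (untested for memory; model and batch are reused from the snippet above)
import torch
from torch.func import functional_call, jvp

def loss_from_params(param_dict):
    # Recompute the loss for the current batch with the given parameters
    # (as in the snippet above, model(batch) returns the loss)
    return functional_call(model, param_dict, (batch,))

primal_params = {name: p.detach() for name, p in model.named_parameters()}
tangents = {name: torch.randn_like(p) for name, p in primal_params.items()}

loss_val, dir_deriv = jvp(loss_from_params, (primal_params,), (tangents,))

# Forward-gradient assignment, analogous to the loop in the snippet above
for name, p in model.named_parameters():
    p.grad = tangents[name] * dir_deriv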
I appreciate any insights or suggestions the community might offer. Thank you for your help!