CUDA Out of Memory Error with jvp() in Large Model Training - Need Insights or Fixes

Hello PyTorch Community,

I am encountering a persistent CUDA out of memory (OOM) error while training a large model (BERT for sequence classification). The issue arises specifically around the use of the jvp() function for Jacobian-vector products in my training loop. Below is an overview of the relevant parts of my code and the specific issues I'm facing.
Note: the memory footprint is much higher than I would expect, and the run also hits an OOM error when I try to profile it.

Code Snippet:

An outline of the relevant parts of the implementation:

# Simplified snippet focusing on the problematic area.
# fwd_function is a closure over the current batch that maps the
# parameter tuple to the scalar loss (sketch below); params is a
# tuple of the model's parameters.
for epoch in range(num_epochs):
    for batch in train_loader:
        # Random tangent vectors, one per parameter
        # (randn_like already allocates on the parameter's device)
        v_params = tuple(torch.randn_like(p) for p in params)

        # Forward pass and loss computation
        # (loss is not used below; fwd_function re-runs the model)
        loss = model(batch)

        # Jacobian-vector product calculation; params and v_params are
        # tuples of tensors and are passed directly, not wrapped again
        _, jvp_out = torch.autograd.functional.jvp(fwd_function, params, v_params)

        # Forward-gradient assignment: g = (J v) * v
        for p, v in zip(params, v_params):
            p.grad = v * jvp_out

        del v_params
        optimizer.step()
        optimizer.zero_grad()

        torch.cuda.empty_cache()  # Attempt to free up unused memory
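
For context, fwd_function is defined inside the batch loop as a closure over the current batch; it rebuilds the model from the parameter tensors and returns the scalar loss. A rough sketch of mine looks like the following (the torch.func.functional_call usage and the Hugging Face-style batch keys are simplifications of my actual setup):

import torch
from torch.func import functional_call

param_names = [name for name, _ in model.named_parameters()]

def fwd_function(*p):
    # Run the model with the tensors in `p` substituted for its
    # named parameters; `batch` is captured from the enclosing loop
    out = functional_call(
        model,
        dict(zip(param_names, p)),
        args=(batch["input_ids"],),
        kwargs={"attention_mask": batch["attention_mask"],
                "labels": batch["labels"]},
    )
    return out.loss  # scalar loss, so the JVP is a scalar as well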

Issue Description:

  • The training process consistently results in an OOM error when calculating jvp().
  • Despite attempts to manage memory by deleting temporary variables and clearing the CUDA cache, the error persists.
  • The model trains without issues when the jvp() call is omitted, suggesting a significant memory overhead associated with this operation.

Attempts to Debug:

  • Memory profiling shows allocation spiking at the jvp() call, but I have not been able to reduce this usage effectively; the instrumentation I used is sketched below.
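
This is roughly how I measured the allocation around the call (standard torch.cuda memory counters; the surrounding loop is as in the snippet above):

torch.cuda.reset_peak_memory_stats(device)
before = torch.cuda.memory_allocated(device)

_, jvp_out = torch.autograd.functional.jvp(fwd_function, params, v_params)

peak = torch.cuda.max_memory_allocated(device)
print(f"allocated before jvp: {before / 2**30:.2f} GiB, "
      f"peak during jvp: {peak / 2**30:.2f} GiB")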

Questions:

  1. Has anyone experienced similar issues with jvp() in large model training scenarios?
  2. Are there known memory optimization strategies, or alternative ways of computing the Jacobian-vector product, that might reduce memory consumption? (One alternative I am considering is sketched after this list.)
  3. Any recommendations for profiling tools or techniques that could provide more detailed insights into memory allocation and deallocation within PyTorch?
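
Regarding question 2: the documentation for torch.autograd.functional.jvp notes that it is implemented via the double-backward trick and points to torch.func.jvp for true forward-mode AD, which I understand should avoid building a backward graph for the JVP itself. The variant I am considering looks roughly like this (untested at scale; the functional_call usage and batch keys are the same assumptions as in the fwd_function sketch above):

import torch
from torch.func import functional_call, jvp

params = {name: p.detach() for name, p in model.named_parameters()}

def loss_fn(p):
    # Run the model with the parameter dict `p`; `batch` is the
    # current mini-batch from the training loop
    out = functional_call(
        model, p,
        args=(batch["input_ids"],),
        kwargs={"attention_mask": batch["attention_mask"],
                "labels": batch["labels"]},
    )
    return out.loss

tangents = {name: torch.randn_like(p) for name, p in params.items()}

# True forward-mode AD: computes the JVP in a single forward pass
loss, jvp_out = jvp(loss_fn, (params,), (tangents,))

# Same forward-gradient estimate as before: g = (J v) * v
for name, p in model.named_parameters():
    p.grad = jvp_out * tangents[name]

Would this be expected to have a substantially smaller memory footprint than the autograd.functional version?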

I appreciate any insights or suggestions the community might offer. Thank you for your help!