OOM error while caching

Hi, my team and I are working with a ViT for a project but ran into an OOM issue. We have two variants of ViT:

  1. Normal ViT
  2. ViT with some form of caching

The OOM error only occurs for the ViT with caching, and only for model sizes larger than 200M parameters (it works fine up to 76M parameters, i.e. for base and small, but not for large).

So far we have tried the following to prevent memory leaks:

  1. Tried to avert it by calling detach (as can be seen in the attached code)
  2. Checked whether broadcasting was making the cached tensors larger than expected (we printed the cache shape in the forward pass and it stayed the same throughout, so that does not look like the issue)
  3. Also tried instantiating the cache with self.register_buffer(…) in the hope that it would automatically avoid any accumulated-gradient issues, but that didn't work in our last run either. We're not entirely clear on how buffers interact with the computational graph and gradient tracking, so this was a hail mary for us (see the sketch after this list).
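
For reference, this is roughly the pattern we are aiming for. It is a simplified sketch, not our actual module; names like CachedBlock, splits_cache, and update_cache are placeholders. As far as we understand, register_buffer only controls how the tensor is registered on the module (device moves, state dict), so the write into it still needs to be detached / done under no_grad:

```python
import torch
import torch.nn as nn

class CachedBlock(nn.Module):
    def __init__(self, dim: int, window_size: int):
        super().__init__()
        # register_buffer keeps the tensor on the module's device and in the
        # state dict, but it does NOT detach anything written into it.
        self.register_buffer("splits_cache", torch.zeros(window_size, dim))
        self.window_size = window_size
        self.step = 0

    def update_cache(self, splits: torch.Tensor) -> None:
        # Writing a detached copy under no_grad keeps the cache out of the
        # autograd graph, so earlier iterations' graphs can be freed.
        with torch.no_grad():
            self.splits_cache[self.step % self.window_size].copy_(splits.detach())
        self.step += 1
```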

I have attached the part of the code where we work with the cache. update_splits_cache_training is called after window_size iterations have passed. Assume that self.splits_func is set to get_ratios_diff.

We want to know whether our handling of the cache has implications for back-prop that we have not figured out (e.g. using while loops, converting tensors to lists and tuples, etc.) and that might be contributing to the OOM.
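
To make the concern concrete, here is a hypothetical toy example (not our actual cache code) of what we suspect might be happening when un-detached tensors end up in a Python list:

```python
import torch

cache = []
x = torch.randn(8, 768, requires_grad=True)

for _ in range(3):
    y = (x * 2).sum(dim=-1)       # y.grad_fn is set; its graph references x
    cache.append(y)               # keeps this iteration's whole graph alive
    # cache.append(y.detach())    # would store only values; graph can be freed

print([t.grad_fn is not None for t in cache])  # [True, True, True]
```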

An interesting note: after window_size iterations we also cleared the CUDA caching allocator with torch.cuda.empty_cache(), which allowed training to go on much longer, but eventually another OOM error showed up.

Any help would be really really appreciated! Thanks.

Detaching curr_cache is a good idea, but since that does not seem to fix the issue, you could check whether other tensors (anything accumulated or stored in a container) still have a valid .grad_fn. If so, you could detach these as well (assuming they are not needed in the backward pass).
If this does not help, you might need to add debug print statements to your training code to isolate which part unexpectedly increases the memory usage.
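
For example (a rough sketch, not tailored to your module; model.splits_cache is a placeholder name), you could scan whatever containers you accumulate into and print the allocator stats once per window:

```python
import torch

def report_cached_graphs(container, name="cache"):
    # Flag any stored tensor that still carries autograd history.
    for i, t in enumerate(container):
        if torch.is_tensor(t) and t.grad_fn is not None:
            print(f"{name}[{i}] is still attached to the graph via {t.grad_fn}")

# Inside the training loop, e.g. once per window:
# report_cached_graphs(model.splits_cache)
# print(f"allocated: {torch.cuda.memory_allocated() / 1e6:.1f} MB | "
#       f"reserved: {torch.cuda.memory_reserved() / 1e6:.1f} MB")
```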
