GPU MEM% allocation vs batch size and temporal dimension

I have a vision transformer training where the input has shape (B, T, C, H, W).

I notice that when I double the batch size B, GPU MEM% also increases by ~2x. On the other hand, when I compare single-timestep training (T=1) with temporal training (e.g. T=16), in this particular example MEM% only increases by ~2x instead of the ~16x I expected.

I inspected the shapes of the intermediate tensors, and most of them have 16x more elements, as expected. Note that I use flash_attn, checkpointing, and factorized_space_time_attn, though I am not sure exactly how these affect the memory footprint.

I would like to understand why the MEM% does not increase proportionally. Thanks for any insight!

In transformers, the space complexity (memory) increases linearly with batch size and quadratically with sequence length, i.e. the number of tokens (not to be confused with the number of frames). The quadratic term comes from materializing the full attention score matrix; note that flash attention avoids materializing it, bringing the attention memory itself back to linear in sequence length.
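To make the scaling concrete, here is a back-of-the-envelope element count for one layer's attention score matrix of shape (B, heads, N, N). The numbers (8 heads, fp16, a 14x14 patch grid) are illustrative assumptions, not values from the post:

```python
def attn_scores_bytes(B, heads, N, bytes_per_elem=2):
    """Bytes needed to materialize one layer's (B, heads, N, N) attention matrix."""
    return B * heads * N * N * bytes_per_elem

base = attn_scores_bytes(B=8, heads=8, N=196)         # e.g. a 14x14 patch grid
print(attn_scores_bytes(B=16, heads=8, N=196) / base)  # 2.0 -> linear in B
print(attn_scores_bytes(B=8, heads=8, N=392) / base)   # 4.0 -> quadratic in N
```

Doubling B doubles the count, while doubling N quadruples it, which is the linear-vs-quadratic distinction above.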

However, different ways of extracting tokens from frames (for example, UniformTemporalSubsample or tubelets) can significantly reduce the memory requirement, since they change how many tokens are actually passed into the transformer. The patch size also affects the number of tokens, so all of these settings influence the memory requirement.

In my setup, the batch and temporal dimensions are flattened into one dimension of size B*T before being passed into the GlobalTransformerLayer. So it is unclear why MEM grows linearly w.r.t. B but sub-linearly w.r.t. T.
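Some shape bookkeeping may help here. The concrete values (B, T, H, W, patch size p) are assumptions for illustration, not taken from the post:

```python
B, T, C, H, W, p = 2, 16, 3, 224, 224, 16

# Flattening batch and time gives B*T "frames" for the spatial layers:
flat_batch = B * T                      # 32 sequences
spatial_seq_len = (H // p) * (W // p)   # 196 tokens each

# Under factorized space-time attention, spatial attention runs over B*T
# sequences of length 196, so its cost is linear in both B and T. Only the
# temporal attention sees sequence length T, over B * 196 sequences.
temporal_batch = B * spatial_seq_len    # 392 sequences of length T
print(flat_batch, spatial_seq_len, temporal_batch)
```

So after the B*T flattening, the factorized layers never attend over a sequence that grows quadratically in T; whether total memory then grows linearly or sub-linearly in T depends on how much of it is activations (reduced by checkpointing) versus weights and optimizer state, which are constant in T.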

What’s the process that takes these flattened tensors and turns them into patch embeddings? And how many patch embeddings go into the transformer when T=1 and when T=16?
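For reference, a common way to produce patch embeddings from such a tensor is a Conv3d with kernel equal to stride. This is a minimal sketch under that assumption, not the poster's actual code; with tubelet depth t=1 it reduces to per-frame patching, which makes the T=1 vs T=16 comparison direct:

```python
import torch
import torch.nn as nn

t, p, dim = 1, 16, 768  # tubelet depth, patch size, embedding dim (assumed)
embed = nn.Conv3d(3, dim, kernel_size=(t, p, p), stride=(t, p, p))

def count_patch_embeddings(T, H=224, W=224, B=1):
    """Number of patch embeddings entering the transformer for a T-frame clip."""
    x = torch.randn(B, 3, T, H, W)                # Conv3d expects (B, C, T, H, W)
    tokens = embed(x).flatten(2).transpose(1, 2)  # -> (B, N, dim)
    return tokens.shape[1]

print(count_patch_embeddings(T=1))   # 1 * 14 * 14 = 196
print(count_patch_embeddings(T=16))  # 16 * 14 * 14 = 3136
```

With t>1 (true tubelets), the temporal factor is divided by t, so the T=16 count would shrink accordingly.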