How to distribute sequence length instead of batch size?

I am working on frames that are part of a video sequence, and was curious whether torch has a method to distribute the sequence length instead of the batch size, similar to DDP?

The primary reason is that I can only fit a batch size of 32 on a single GPU for a still-image object detection task, i.e. 32 images per batch with shape (batch_size, image_dims). However, since we are now working on sequences, there is an additional sequence-length dimension: (batch_size, image_dims, sequence_length).
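For concreteness, the shapes I am talking about look roughly like this (the channel/resolution values and the dimension ordering are just placeholders, not my actual data):

```python
import torch

# Still-image detection: one batch of 32 independent images
still_batch = torch.randn(32, 3, 224, 224)       # (batch_size, C, H, W)

# Video detection: the same kind of images now carry a sequence-length axis
video_batch = torch.randn(1, 32, 3, 224, 224)    # (batch_size, sequence_length, C, H, W)
```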

I am assuming that I will be able to fit a single batch with a sequence length of 32 onto one GPU. With DDP I can increase the effective total batch size, but 32 frames per video, for videos that are 420 frames long, leads to a loss of temporal information. I would like to increase the sequence length to 128 without running into CUDA out-of-memory errors.

I am already using the following strategies to reduce memory usage (rough sketch after the list):

  1. mixed precision
  2. gradient (activation) checkpointing
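
A self-contained toy version of what I am doing looks roughly like this (the modules, sizes, and loss are just stand-ins for my actual detector):

```python
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.checkpoint import checkpoint

# Tiny stand-ins for the real model; shapes and sizes are placeholders.
backbone = nn.Sequential(nn.Flatten(start_dim=2), nn.Linear(3 * 64 * 64, 128)).cuda()
head = nn.Linear(128, 10).cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(list(backbone.parameters()) + list(head.parameters()), lr=1e-3)
scaler = GradScaler()

frames = torch.randn(2, 32, 3, 64, 64, device="cuda")   # (B, T, C, H, W)
targets = torch.randn(2, 32, 10, device="cuda")

optimizer.zero_grad(set_to_none=True)
with autocast():                                          # mixed precision
    # activation checkpointing: recompute the backbone's activations during
    # the backward pass instead of storing them, trading compute for memory
    feats = checkpoint(backbone, frames, use_reentrant=False)
    loss = criterion(head(feats), targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```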

Is there a way to distribute the sequence length across GPUs, or something along those lines?
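
Conceptually, what I am imagining is something like the sketch below: every rank loads the same batch (no DDP-style sharding of the data), runs the per-frame backbone on its own slice of the time axis, and the per-frame features are gathered back so a temporal module can see the full sequence. The dimension order, the feature size, the divisibility of T by the world size, and an already-initialized process group are all my assumptions, and I am not sure this is the idiomatic way to do it:

```python
import torch
import torch.distributed as dist
import torch.distributed.nn as dist_nn   # autograd-aware collectives


def forward_sequence_parallel(backbone, frames):
    """frames: (B, T, C, H, W), identical on every rank.
    Each rank runs the expensive per-frame backbone on T / world_size frames
    (assumes T is divisible by the world size)."""
    rank = dist.get_rank()
    world = dist.get_world_size()

    local_frames = frames.chunk(world, dim=1)[rank]   # (B, T/world, C, H, W)
    local_feats = backbone(local_frames)              # (B, T/world, F), assumed

    # all_gather from torch.distributed.nn keeps the autograd graph, so a
    # temporal head applied after the concat still gets gradients for every chunk
    gathered = dist_nn.all_gather(local_feats)
    return torch.cat(gathered, dim=1)                 # (B, T, F)
```

If there is a built-in or recommended way to achieve this, I would much rather use that than roll my own collective logic.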