Partially sharded training

I have a large (hundreds of GB) time series dataset with high-frequency samples. I would like to fit a recurrent model to predict features of this dataset in the following way:

  1. Split the data up into intervals (say, 5 minutes in length)
  2. On each interval, compute in parallel F(incoming_state, parameters) = (score, outgoing_state, d_score_d_param, d_score_d_incoming_state, d_outgoing_state_d_param, d_outgoing_state_d_incoming_state) — see the sketch after this list for what I mean
  3. Using the gradients from 2), optimize the parameters over the entire dataset at once
  4. Repeat 2) and 3) until convergence.
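
For concreteness, here is a minimal sketch of what I have in mind for step 2), using `torch.func` (`functional_call`, `grad`, `jacrev`) on a toy `nn.GRUCell`; the cell, sizes, and scoring function are just placeholders for my real model:

```python
import torch
from torch import nn
from torch.func import functional_call, grad, jacrev

# Placeholder recurrent cell and sizes; my real model is larger.
cell = nn.GRUCell(input_size=8, hidden_size=16)
params = dict(cell.named_parameters())

def run_interval(p, incoming_state, interval):
    """Roll the cell over one interval; return (score, outgoing_state)."""
    h = incoming_state                      # shape (1, hidden_size)
    for x_t in interval:                    # interval has shape (T, input_size)
        h = functional_call(cell, p, (x_t.unsqueeze(0), h))
    score = h.pow(2).mean()                 # placeholder scalar score
    return score, h

def score_only(p, h0, interval):
    return run_interval(p, h0, interval)[0]

def state_only(p, h0, interval):
    return run_interval(p, h0, interval)[1]

def F(incoming_state, p, interval):
    """The per-interval computation from step 2)."""
    score, outgoing_state = run_interval(p, incoming_state, interval)
    # Gradients of the scalar score w.r.t. parameters and incoming state.
    d_score_d_param, d_score_d_in = grad(score_only, argnums=(0, 1))(
        p, incoming_state, interval)
    # Full Jacobians of the outgoing state; these have an extra hidden_size
    # dimension per parameter tensor, so they only stay cheap if the state is small.
    d_out_d_param, d_out_d_in = jacrev(state_only, argnums=(0, 1))(
        p, incoming_state, interval)
    return (score, outgoing_state,
            d_score_d_param, d_score_d_in,
            d_out_d_param, d_out_d_in)

# One interval: e.g. 300 high-frequency samples in a 5-minute window.
h0 = torch.zeros(1, 16)
interval = torch.randn(300, 8)
out = F(h0, params, interval)
```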

The motivation is that step 3) can fit on a single GPU, because step 2) effectively “compresses” the data into per-interval scores, states, and gradients; step 2) is very expensive but embarrassingly parallel, and hopefully I can get away with running it less frequently by passing its gradients to step 3) (see the sketch below).
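
Here is roughly how I picture step 3) stitching the per-interval Jacobians into a gradient over the whole dataset via the chain rule, assuming everything has been flattened into plain vectors and matrices (just a sketch of the recursion, not something I have working):

```python
import torch

def total_gradient(chunks):
    """chunks: time-ordered list of dicts with flattened per-interval outputs:
       'd_score_d_param' (P,), 'd_score_d_in' (H,),
       'd_out_d_param' (H, P), 'd_out_d_in' (H, H).
       Returns dL/d_param for L = sum of interval scores."""
    P = chunks[0]["d_score_d_param"].numel()
    H = chunks[0]["d_score_d_in"].numel()
    dh_dtheta = torch.zeros(H, P)   # sensitivity of chunk 0's incoming state to params
    g = torch.zeros(P)
    for c in chunks:
        # This chunk's score depends on params directly and through its incoming state.
        g += c["d_score_d_param"] + c["d_score_d_in"] @ dh_dtheta
        # Propagate the state sensitivity forward to the next chunk's incoming state.
        dh_dtheta = c["d_out_d_param"] + c["d_out_d_in"] @ dh_dtheta
    return g
```

The resulting gradient would then feed a plain optimizer step on a single GPU (e.g. copy slices of it into each `param.grad` and call `optimizer.step()`).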

How can I implement such a framework using PyTorch?