Optimize multiple backward calls through part of network

My network consists of two parts: a feature extractor/convolutional upsampler, and a 2D recurrent net that takes the upsampler's output plus some previous activation maps (at the same resolution) and produces another activation map. The RNN runs over multiple time steps (2 previous input activation maps per time step, 1 output map), all sharing the same upsampler output. I'm trying to backpropagate through both parts of the network at the same time without having to recompute the upsampled map at each time step. In code, my naive training loop looks like this:

for epoch in range(num_epochs):
    img, act_maps, target_maps = get_batch()  # act_maps: T, 2, H, W; target_maps: T, H, W
    feature_map = upsample_net(img)  # feature_map: 1, 128, H, W
    for in_act, target in zip(act_maps, target_maps):
        # concatenate along the channel dimension: 128 + 2 = 130 channels
        i = torch.cat((feature_map, in_act.unsqueeze(0)), dim=1)  # i: 1, 130, H, W
        o = recurrent_net(i)  # o: 1, H, W
        loss = criterion(o, target)
        loss.backward()  # fails on the second iteration without retain_graph=True

Obviously calling backward() T times raises an error, which retain_graph=True suppresses at the cost of incredibly slow training, since the computation graphs of all previous time steps are kept alive forever.
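To illustrate the failure mode outside my actual nets, here is a minimal, self-contained toy (plain tensors standing in for the upsampler and RNN): a shared subgraph feeds T per-step losses, and the second backward() call would error without retain_graph=True because the shared part of the graph is freed after the first call.

```python
import torch

w = torch.randn(4, requires_grad=True)
feature = (w * 2).sum()          # shared subgraph, computed once (like feature_map)
losses = [feature * t for t in range(1, 4)]  # one loss per "time step"

losses[0].backward(retain_graph=True)  # OK: graph kept alive for the next call
losses[1].backward(retain_graph=True)  # OK: gradients accumulate into w.grad
losses[2].backward()                   # last call, graph can be freed now
```

In the real network each retained graph also includes every previous run of the recurrent net, which is where the slowdown comes from.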

I’ve tried batching everything into a single forward pass by expanding feature_map to shape T, 128, H, W and concatenating it with act_maps, but this is too memory-inefficient. Likewise, I’d like to avoid using minibatches, as the hidden state isn’t reinitialized between time steps and it would make my data loading code fairly convoluted.
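For concreteness, the batched attempt looked roughly like this (toy sizes here: T=3 time steps and 4 feature channels instead of 128). expand() itself is only a view, but the concatenation materializes a full (T, C+2, H, W) tensor, which is what blows up memory at real resolutions:

```python
import torch

T, C, H, W = 3, 4, 2, 2              # stand-ins for the real T, 128, H, W
feature_map = torch.randn(1, C, H, W)
act_maps = torch.randn(T, 2, H, W)

# expand is a zero-copy view, but cat copies everything into one big tensor
batched = torch.cat((feature_map.expand(T, -1, -1, -1), act_maps), dim=1)
print(batched.shape)  # torch.Size([3, 6, 2, 2])
```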

Does anybody have a clue on how to backpropagate through the upsampling network multiple times without retaining all the autograd garbage from the recurrent net’s previous runs?