Yeah, tough question but an interesting one.
I guess you need to create an "unrolled" network with gradients flowing through it: either accumulate a loss at each timestep, or just compute a single loss on the output of the final timestep, and then call backward() once on that. As long as the computation graph stays connected across the steps, the gradients should flow back through "time".
This post implies you can achieve that by using the same variable for input and output. Maybe give that a shot?
# non-truncated: backprop through the entire unrolled sequence
for t in range(T):
    out = model(out)
out.backward()

# truncated to the last K timesteps
for t in range(T):
    out = model(out)
    if T - t == K:
        out = out.detach()  # cut the graph here so gradients stop flowing further back
out.backward()
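If it helps, here's a minimal self-contained sketch of the "sum a loss at each timestep" version. The cell, readout layer, sizes, and the MSE loss are all placeholders I made up just for illustration, not anything from your setup:

import torch
import torch.nn as nn

T, K = 20, 5
input_size, hidden_size = 8, 16

cell = nn.GRUCell(input_size, hidden_size)     # placeholder recurrent cell
readout = nn.Linear(hidden_size, input_size)   # placeholder output layer
loss_fn = nn.MSELoss()

inputs = torch.randn(T, 1, input_size)         # dummy sequence: (T, batch, features)
targets = torch.randn(T, 1, input_size)

h = torch.zeros(1, hidden_size)                # hidden state, reused across steps
total_loss = torch.zeros(())
for t in range(T):
    h = cell(inputs[t], h)
    if T - t == K:
        h = h.detach()  # optional truncation: gradients from later losses stop here
    total_loss = total_loss + loss_fn(readout(h), targets[t])

total_loss.backward()  # one backward pass through the unrolled graph

Note that with the per-timestep losses, the detach only stops gradients from the later losses reaching further back; the losses computed before the detach still backprop through their own earlier steps.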