I’m working on an RNN that is broken down into several modules, and one of these modules estimates the accuracy of a given result. The RNN has multiple ways of computing the result, and by estimating the accuracy of a result, it can gauge the efficacy of the technique used. I want to use a transformer for this module because of its strong predictive power, but I can’t figure out a way to do it without recalculating results multiple times.
At each timestep, the module receives the input-output pair of tensors from the rest of the RNN, and its output is used in the next step to modify the RNN’s technique. The issue is that at each timestep I have to recalculate all the previous values.
For example:
step 1: X1 -> RNN -> Y1; ((X1),(Y1))->Transformer->A1
step 2: X2 -> RNN -> Y2; ((X1,X2),(Y1,Y2))->Transformer->(A1,A2)
step 3: X3 -> RNN -> Y3; ((X1,X2,X3),(Y1,Y2,Y3))->Transformer->(A1,A2,A3)
I only need A1 from step 1, A2 from step 2, and A3 from step 3, but step 3 calculates A1, A2 and A3.
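To get a rough sense of how fast the recompute scheme above grows, here is a back-of-the-envelope sketch that counts only the query-key dot products in a single attention layer over T timesteps. It assumes one token per (X_t, Y_t) pair and ignores projections and feed-forward layers; the function names are made up for illustration. If the autograd graph for every step is kept for backprop, memory grows roughly in proportion to these counts.

```python
def scores_full_recompute(T, causal=True):
    # At step t the whole prefix of length t is re-run through the layer.
    # With a causal mask, query i only scores keys 1..i: t*(t+1)/2 scores.
    # Without a mask, every query scores every key: t*t scores.
    total = 0
    for t in range(1, T + 1):
        total += t * (t + 1) // 2 if causal else t * t
    return total


def scores_cached(T):
    # With cached keys/values, step t scores one new query against t keys.
    return sum(range(1, T + 1))


for T in (10, 100):
    print(T, scores_full_recompute(T), scores_full_recompute(T, causal=False), scores_cached(T))
# 10 220 385 55
# 100 171700 338350 5050
```

Recomputing the prefix makes the total work cubic in T, while computing only the newest position makes it quadratic.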
Ideally I would want:
step 1: X1 -> RNN -> Y1; (X1,Y1)->Transformer->A1
step 2: X2 -> RNN -> Y2; (X2,Y2)->Transformer->A2
step 3: X3 -> RNN -> Y3; (X3,Y3)->Transformer->A3
Doing it this way shouldn’t affect the accuracy of the model as a whole, but the current approach is nearly impossible to train because of the giant computation graphs it creates: my computer runs out of memory almost instantly.
I’ve looked at KV caching, but it seems to be used exclusively for inference, and I also don’t know how to implement it.
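The core of KV caching is small enough to sketch in plain Python. This is a minimal single-head version, assuming each (X_t, Y_t) pair has already been embedded into a vector `h` and that the transformer uses a causal mask. The weight matrices, `CachedAttention`, and `full_causal` are hypothetical names for illustration, and layer norm, multiple heads, and feed-forward layers are left out. Each step projects only the new token, appends its key and value to the cache, and attends over everything cached so far. The check at the end confirms that this gives the same per-step outputs as re-running the whole masked prefix.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(W, v):
    # W is a list of rows; returns W @ v.
    return [dot(row, v) for row in W]


def attend(q, keys, values):
    # Scaled dot-product attention of a single query over the given keys/values.
    scores = [dot(q, k) / math.sqrt(len(q)) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]


class CachedAttention:
    """Hypothetical single-head causal self-attention with a KV cache."""

    def __init__(self, Wq, Wk, Wv):
        self.Wq, self.Wk, self.Wv = Wq, Wk, Wv
        self.keys, self.values = [], []

    def step(self, h):
        # Only the newest token is projected; earlier keys/values come from the cache.
        self.keys.append(matvec(self.Wk, h))
        self.values.append(matvec(self.Wv, h))
        return attend(matvec(self.Wq, h), self.keys, self.values)


def full_causal(Wq, Wk, Wv, hs):
    # Reference: recompute every position, letting position t see only 0..t.
    keys = [matvec(Wk, h) for h in hs]
    values = [matvec(Wv, h) for h in hs]
    return [attend(matvec(Wq, h), keys[: t + 1], values[: t + 1]) for t, h in enumerate(hs)]


# Made-up weights and embeddings, purely for the equivalence check.
Wq = [[0.5, -0.2], [0.1, 0.3]]
Wk = [[0.4, 0.0], [-0.3, 0.2]]
Wv = [[1.0, 0.5], [0.0, -1.0]]
hs = [[1.0, 0.0], [0.5, -0.5], [0.2, 0.9]]

attn = CachedAttention(Wq, Wk, Wv)
incremental = [attn.step(h) for h in hs]  # one new output per step
reference = full_causal(Wq, Wk, Wv, hs)
for a, b in zip(incremental, reference):
    assert all(abs(x - y) < 1e-12 for x, y in zip(a, b))
```

Two caveats. First, this equivalence depends on the causal mask: with an unmasked encoder, A1 would change once X2 arrives, so the earlier outputs could not simply be reused. Second, nothing about the cache is inference-specific. In PyTorch, cached key/value tensors stay attached to the autograd graph, so gradients still reach earlier steps and memory grows with the number of cached steps. Calling `.detach()` on the cached tensors would cap that memory, at the cost of not backpropagating into earlier steps, similar to truncated backprop through time in an RNN.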
Maybe I could copy and paste one of those ‘make a transformer from scratch’ tutorials, but I’d prefer not to.
Anyone else have an idea how to solve this?