Splitting a model into three while JIT tracing

How do I split a model into three separate modules when JIT tracing, so that each one is saved as its own PyTorch archive?

For context, it is an RNN-T model for ASR, and I need to split it into the encoder network, the prediction network, and the joint network.
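
To make the goal concrete, here is a minimal sketch of the intended outcome, not the actual model. It assumes the three networks are available as separate `nn.Module` objects; the classes `Encoder`, `Predictor`, and `Joiner`, their layer sizes, and the file names are all hypothetical stand-ins. Each submodule is traced on its own with `torch.jit.trace` and written to its own archive with `torch.jit.save`.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical stand-ins for the three RNN-T components; the real model's
# layers and sizes will differ.
class Encoder(nn.Module):
    def __init__(self, feat_dim=80, hidden=32):
        super().__init__()
        self.rnn = nn.LSTM(feat_dim, hidden, batch_first=True)

    def forward(self, feats):
        out, _ = self.rnn(feats)  # (B, T, H)
        return out


class Predictor(nn.Module):
    def __init__(self, vocab=16, hidden=32):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.rnn = nn.LSTM(hidden, hidden, batch_first=True)

    def forward(self, tokens):
        out, _ = self.rnn(self.embed(tokens))  # (B, U, H)
        return out


class Joiner(nn.Module):
    def __init__(self, hidden=32, vocab=16):
        super().__init__()
        self.proj = nn.Linear(hidden, vocab)

    def forward(self, enc, pred):
        # Broadcast-add encoder and predictor outputs: (B, T, U, H) -> (B, T, U, V)
        return self.proj(torch.tanh(enc.unsqueeze(2) + pred.unsqueeze(1)))


encoder, predictor, joiner = Encoder().eval(), Predictor().eval(), Joiner().eval()

# Example inputs; tracing records the ops executed for these inputs only.
feats = torch.randn(2, 50, 80)
tokens = torch.randint(0, 16, (2, 5))

with torch.no_grad():
    enc_out = encoder(feats)
    pred_out = predictor(tokens)
    # Trace each component separately, feeding the joiner real intermediate outputs.
    traced = {
        "encoder": torch.jit.trace(encoder, feats),
        "predictor": torch.jit.trace(predictor, tokens),
        "joiner": torch.jit.trace(joiner, (enc_out, pred_out)),
    }

# Save one archive per component (a temp dir is used here for illustration).
out_dir = tempfile.mkdtemp()
for name, module in traced.items():
    torch.jit.save(module, os.path.join(out_dir, f"{name}.pt"))

# Reload the three archives independently and chain them.
loaded = {n: torch.jit.load(os.path.join(out_dir, f"{n}.pt")) for n in traced}
with torch.no_grad():
    logits = loaded["joiner"](loaded["encoder"](feats), loaded["predictor"](tokens))
```

Because tracing only records the operations run on the example inputs, any data-dependent control flow inside a component would not be captured this way. If the three networks exist only as parts of one `forward`, they would first need to be exposed as separate submodules (or wrapped in small wrapper modules) before each can be traced.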