How to do .forward on part of my model?

The background of this question is the same as in this one. I realized that, to train that middle-layer parameter, I can save the input of that layer and run the forward pass from there, so that training could finish even faster. But I am not sure how to implement this in PyTorch. Do I have to rewrite the .forward function of the entire model, or is there a simpler way to achieve that?

I think creating a new custom model with a new forward method would be the cleanest approach.
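A minimal sketch of the custom-model approach, assuming a toy network: `FullModel`, `PartialModel`, the `front` submodule and all layer sizes are hypothetical, and only `middle_layer1`/`middle_layer2` mirror the names used in this thread. The partial model holds references to the original submodules (nothing is copied), so training it updates the original model's parameters too.

```python
import torch
import torch.nn as nn

class FullModel(nn.Module):
    # Hypothetical stand-in for the original model; names and sizes are assumptions.
    def __init__(self):
        super().__init__()
        self.front = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.middle_layer1 = nn.Linear(32, 32)
        self.middle_layer2 = nn.Linear(32, 4)

    def forward(self, x):
        x = self.front(x)
        x = torch.relu(self.middle_layer1(x))
        return self.middle_layer2(x)

class PartialModel(nn.Module):
    """Runs only the part of FullModel that starts at middle_layer1.

    The submodules are shared with the full model, not copied.
    """
    def __init__(self, full):
        super().__init__()
        self.middle_layer1 = full.middle_layer1
        self.middle_layer2 = full.middle_layer2

    def forward(self, middle_activation):
        out = torch.relu(self.middle_layer1(middle_activation))
        return self.middle_layer2(out)

torch.manual_seed(0)
full = FullModel()
partial = PartialModel(full)

x = torch.randn(8, 16)
middle_activation = full.front(x)  # the input of middle_layer1
# Starting from the saved activation gives the same result as the full forward.
print(torch.allclose(partial(middle_activation), full(x)))
```

The catch with this design is that the partial forward duplicates the tail of the original forward (here the `torch.relu` between the two layers), so the two must be kept in sync if the model changes.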
Alternatively, you could also use the modules by directly accessing them via:

middle_activation = ...  # the saved input of middle_layer1
out = model.middle_layer1(middle_activation)
out = model.middle_layer2(out)  # continue through the remaining layers

but this would basically be equivalent to writing a new custom forward.
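A sketch of the workflow described in the question, under assumed names and random placeholder data: the frozen `front` part (hypothetical) is run once under `torch.no_grad()` to cache the middle activations, and then only `middle_layer1`/`middle_layer2` are trained by calling the modules directly. Naming the layers through an `OrderedDict` in `nn.Sequential` is just one convenient way to make them accessible as attributes.

```python
from collections import OrderedDict
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model; only middle_layer1 and middle_layer2 are trained.
model = nn.Sequential(OrderedDict(
    front=nn.Sequential(nn.Linear(16, 32), nn.ReLU()),
    middle_layer1=nn.Linear(32, 32),
    act=nn.ReLU(),
    middle_layer2=nn.Linear(32, 4),
))

inputs = torch.randn(64, 16)   # random placeholder data
targets = torch.randn(64, 4)

# 1) Run the frozen front part once and cache its output (no graph is kept).
with torch.no_grad():
    middle_activation = model.front(inputs)

# 2) Optimize only the parameters of the middle layers.
params = list(model.middle_layer1.parameters()) + list(model.middle_layer2.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)
loss_fn = nn.MSELoss()

front_before = [p.clone() for p in model.front.parameters()]
losses = []
for _ in range(50):
    optimizer.zero_grad()
    # Forward starts from the cached activation instead of the raw inputs.
    out = model.middle_layer1(middle_activation)
    out = model.act(out)
    out = model.middle_layer2(out)
    loss = loss_fn(out, targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

Because the cached activation was computed without gradient tracking, `backward()` never reaches the `front` layers, which is where the speedup comes from. This only holds while the layers before the cached point stay frozen; if they change, the cache has to be recomputed.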

Thanks! I’m going to try it out.