Load only some layers in PyTorch

Hello,
I have a model that I trained on RGB images, and now I want to train it on raw data. Training is very slow, though: it took me a week to train the model on RGB data with a good GPU.
I would like to load the weights of some of the layers from the RGB-trained model and reuse them in my new model, which works with raw data.
Basically, only the first layers change, and I would like to load the weights into every other layer. Does anyone know how to do this?
Thanks a lot!

I think loading the entire trained state_dict and then re-initializing the first layers would be the easiest and cleanest approach. Alternatively, you could load the parameters layer by layer, but that sounds unnecessarily complicated.
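
Here is a minimal sketch of both approaches mentioned above. Everything about the architecture is an assumption for illustration: `make_model` is a hypothetical stand-in for your real network, the layer sizes are made up, and I am assuming the raw input has 4 channels instead of 3. Adapt the layer names and shapes to your own model.

```python
import torch
import torch.nn as nn


def make_model(in_channels: int) -> nn.Sequential:
    # Hypothetical toy network; only the first conv depends on the input format.
    return nn.Sequential(
        nn.Conv2d(in_channels, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(16, 10),
    )


# Stand-in for the model already trained on RGB data.
# In practice you would restore it from your own checkpoint instead.
rgb_model = make_model(in_channels=3)

# Approach 1: load the full state_dict into the original architecture,
# then swap in a freshly initialized first layer for the raw input.
raw_model = make_model(in_channels=3)
raw_model.load_state_dict(rgb_model.state_dict())
raw_model[0] = nn.Conv2d(4, 8, kernel_size=3, padding=1)  # 4 channels is an assumption

# Approach 2: build the new architecture directly and load only the
# entries that do not belong to the changed layer(s).
raw_model2 = make_model(in_channels=4)
skip_prefixes = ("0.",)  # state_dict key prefix of the first conv in this toy model
filtered = {
    name: tensor
    for name, tensor in rgb_model.state_dict().items()
    if not name.startswith(skip_prefixes)
}
# strict=False lets the skipped layer keep its fresh initialization.
result = raw_model2.load_state_dict(filtered, strict=False)
print("Layers left at random init:", result.missing_keys)
```

With either variant, the first layer starts from a random initialization while the remaining layers start from the RGB weights. Checking `result.missing_keys` in the second variant is a quick way to confirm that only the layers you meant to skip were left out.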