How to find out whether a layer is trained or not

Hello everyone,
I have a model that can use one of several topologies, each with its own forward function. During training I choose a topology, and that choice determines which forward implementation is used. Later I need to traverse the trained layers without knowing which topology the model was trained with. Currently, when I load the model with net.load_state_dict(checkpoint), I get all the layers: both those used in the chosen forward function (trained) and those that are not (untrained).

I have an idea, but I don't know if there is a better way. After torch.load, I would loop over checkpoint.keys() and split each name on '.' to extract the first part (e.g. encoder_1.conv1.bias would give encoder_1). I would collect those names in a list and use it to pick out the corresponding layers from the net after net.load_state_dict(checkpoint). But this looks too hacky, doesn't it? Is there a more straightforward way?
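One caveat with the key-splitting idea: if the checkpoint was written from net.state_dict(), it contains every registered submodule's parameters and buffers, whether or not forward() used them. So the prefixes alone cannot tell trained layers from untrained ones. A minimal sketch, assuming you can change how the checkpoint is saved, is to store the chosen topology next to the weights and keep a mapping from topology to the submodules it uses. The topology names, the TOPOLOGY_LAYERS mapping, and the helper functions below are all hypothetical; plain strings stand in for tensors so the logic is easy to check.

```python
# Hypothetical mapping from topology name to the top-level submodules that
# topology's forward() actually uses. Adjust to your own model.
TOPOLOGY_LAYERS = {
    "topology_a": ["encoder_1", "decoder"],
    "topology_b": ["encoder_2", "decoder"],
}


def top_level_prefixes(state_dict_keys):
    """Unique first components of dotted keys, in first-seen order.

    e.g. 'encoder_1.conv1.bias' -> 'encoder_1'
    """
    # dict preserves insertion order (Python 3.7+), so this dedupes in order.
    return list(dict.fromkeys(key.split(".", 1)[0] for key in state_dict_keys))


def make_checkpoint(state_dict, topology):
    # Save the topology name alongside the weights so it can be recovered later.
    return {"topology": topology, "state_dict": state_dict}


def trained_prefixes(checkpoint):
    # Only the submodules listed for the saved topology were used in forward().
    return list(TOPOLOGY_LAYERS[checkpoint["topology"]])


def trained_state_dict(checkpoint):
    # Keep only the entries whose top-level prefix belongs to a trained submodule.
    used = set(trained_prefixes(checkpoint))
    return {
        key: value
        for key, value in checkpoint["state_dict"].items()
        if key.split(".", 1)[0] in used
    }
```

With PyTorch you would save with something like torch.save(make_checkpoint(net.state_dict(), "topology_a"), path), then after ckpt = torch.load(path) and net.load_state_dict(ckpt["state_dict"]), iterate over getattr(net, name) for name in trained_prefixes(ckpt). Recording the topology explicitly avoids guessing from key names, which only works if untrained layers were never saved.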