How to customize load_state_dict?

I have a custom convolution layer that supports two weight layouts, KRSC and RSKC (RS is the kernel size, K is the number of output channels, C is the number of input channels), which are used by different algorithms. Now I want to load an RSKC state dict into a KRSC layer, but I can’t find any useful hook to achieve this. A hook named “_register_load_state_dict_pre_hook” exists, but it’s a private method and there is no guarantee that it will be available in torch 1.5+.

You can manipulate the tensors in the state_dict directly before loading it.
Here is a simple example:

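import torch.nn as nn

conv1 = nn.Conv2d(3, 6, 3)
conv2 = nn.Conv2d(3, 6, 2)

sd = conv1.state_dict()

# won't work: conv1's 3x3 kernel does not fit conv2's 2x2 kernel
try:
    conv2.load_state_dict(sd)
except RuntimeError as err:
    print("RuntimeError:", err)
# > RuntimeError...

# manipulate: crop the kernel to 2x2 so the shapes match
sd['weight'] = sd['weight'][:, :, :2, :2]

# works
print(conv2.load_state_dict(sd))
# > <All keys matched successfully>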
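For the layout question itself, the same idea should carry over: permute the weight in the state_dict before loading it. Below is a minimal sketch, not a definitive implementation. KRSCConv is a hypothetical stand-in for the custom layer, rskc_to_krsc is a helper name made up for this example, and it assumes the weight is stored under the key 'weight' with shape (R, S, K, C) in the checkpoint and (K, R, S, C) in the layer.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class KRSCConv(nn.Module):
    """Hypothetical stand-in for the custom layer; weight is stored as (K, R, S, C)."""

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        r, s = kernel_size
        self.weight = nn.Parameter(torch.zeros(out_channels, r, s, in_channels))


def rskc_to_krsc(state_dict, key="weight"):
    """Return a copy of state_dict with `key` permuted from RSKC to KRSC."""
    converted = OrderedDict(state_dict)  # shallow copy; the input dict is not modified
    # (R, S, K, C) -> (K, R, S, C)
    converted[key] = state_dict[key].permute(2, 0, 1, 3).contiguous()
    return converted


# Pretend this came from a checkpoint saved in RSKC layout: R=3, S=3, K=6, C=4.
rskc_sd = OrderedDict(weight=torch.randn(3, 3, 6, 4))

layer = KRSCConv(in_channels=4, out_channels=6, kernel_size=(3, 3))
result = layer.load_state_dict(rskc_to_krsc(rskc_sd))
```

Since the conversion happens on the dict before load_state_dict is called, this does not rely on any private hook. If the weight sits under a prefixed key inside a larger model, the key argument would need to be adjusted accordingly.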