I have a custom convolution layer that supports two layouts, KRSC and RSKC (RS is the kernel size, K is the number of output channels, C is the number of input channels), which are used by different algorithms. Now I want to load an RSKC state dict into a KRSC layer, but I can’t find any useful hooks to achieve this. A hook named “_register_load_state_dict_pre_hook” exists, but it’s a private method and isn’t guaranteed to be available in torch 1.5+.
You can manipulate the tensors in the state_dict directly before loading it.
Here is a simple example:
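```python
import torch.nn as nn

conv1 = nn.Conv2d(3, 6, 3)
conv2 = nn.Conv2d(3, 6, 2)
sd = conv1.state_dict()

# won't work: the weight shapes differ (3x3 kernel vs. 2x2 kernel)
try:
    conv2.load_state_dict(sd)
except RuntimeError as err:
    print(err)  # > RuntimeError...

# manipulate: crop the 3x3 kernel to 2x2 so it matches conv2
sd['weight'] = sd['weight'][:, :, :2, :2]

# works
conv2.load_state_dict(sd)
# > <All keys matched successfully>
```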
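For the layout case in the question, the same idea could look roughly like the sketch below: permute the weight from RSKC to KRSC in a copy of the state dict, then load it as usual. `KRSCConv` and `rskc_to_krsc` are hypothetical names made up for this illustration (they stand in for your custom layer and are not part of PyTorch), and I'm assuming the weight is stored under the key `weight`.

```python
import torch
import torch.nn as nn


def rskc_to_krsc(state_dict, weight_keys=("weight",)):
    """Return a shallow copy of state_dict with the listed weights
    permuted from RSKC to KRSC. `weight_keys` is an assumption about
    which entries hold convolution weights."""
    converted = dict(state_dict)
    for key in weight_keys:
        # dims of the source tensor: 0=R, 1=S, 2=K, 3=C
        # target order K, R, S, C -> permute(2, 0, 1, 3)
        converted[key] = state_dict[key].permute(2, 0, 1, 3).contiguous()
    return converted


class KRSCConv(nn.Module):
    """Hypothetical stand-in for the custom layer; it only holds a
    KRSC-shaped weight so the loading step can be demonstrated."""

    def __init__(self, k, r, s, c):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(k, r, s, c))


R, S, K, C = 2, 3, 6, 4
rskc_sd = {"weight": torch.randn(R, S, K, C)}  # pretend checkpoint in RSKC

layer = KRSCConv(K, R, S, C)
result = layer.load_state_dict(rskc_to_krsc(rskc_sd))
```

If the checkpoint comes from a larger model, the keys will carry module prefixes (e.g. something like `block.conv.weight`), so you would pass those names in `weight_keys` or select them by suffix. This keeps everything on the public `state_dict()` / `load_state_dict()` API and avoids the private hook.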