State_dict for weight norm Conv2d doesn't give weight_g and weight_v keys

Hi, I’m trying to recreate a decoder from pretrained weights, and I’m using a Conv2d layer wrapped in weight_norm. In the state_dict of the layer I created there is only a single “weight” key, whereas I’m expecting a “weight_v” and a “weight_g” key as in the original model. Any idea what could be causing this, and how I can get the state_dict to expose the “weight_v” and “weight_g” entries? I need them in order to load the pretrained weights onto the model I’m creating.

Weight-normalized modules have weight_g and weight_v in their state_dict unless nn.utils.remove_weight_norm has been applied to them, as follows.

In [5]: wn_linear = weight_norm(nn.Linear(2, 2, bias=False))

In [6]: wn_linear.state_dict()
Out[6]:
OrderedDict([('weight_g', tensor([[0.3002],
                                  [0.3427]])),
             ('weight_v', tensor([[-0.2730, -0.1248],
                                  [-0.3416, -0.0280]]))])

In [7]: nn.utils.remove_weight_norm(wn_linear).state_dict()
Out[7]:
OrderedDict([('weight', tensor([[-0.2730, -0.1248],
                                [-0.3416, -0.0280]]))])

I’m not entirely sure this fits your case, but one trick that comes to mind is to register weight normalization after calling load_state_dict, as follows:
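A minimal sketch of that trick. The decoder here is a made-up single-Conv2d module and the pretrained tensor is random, just to illustrate the ordering: load the plain "weight" first, then wrap the layer in weight_norm, which splits "weight" into "weight_g" and "weight_v" without changing the effective weight.

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

# Hypothetical decoder with a plain (not yet weight-normalized) Conv2d.
decoder = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, bias=False))

# Stand-in for your pretrained state_dict, keyed by plain 'weight'.
pretrained = {'0.weight': torch.randn(16, 3, 3, 3)}

# 1. Load while the layer still exposes a single 'weight' key.
decoder.load_state_dict(pretrained)

# 2. Register weight normalization afterwards; the state_dict now
#    contains 'weight_g' and 'weight_v', and the effective weight
#    (g * v / ||v||) is identical to the loaded one.
decoder[0] = weight_norm(decoder[0])

print(sorted(decoder.state_dict().keys()))
```

If your pretrained checkpoint already stores weight_g/weight_v (as your original model does), the reverse order also works: apply weight_norm to the fresh layer first, then call load_state_dict with the weight_g/weight_v checkpoint directly.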
