Hi, I’m trying to recreate a decoder with pretrained weights and I’m using a Conv2d layer with weight_norm. In the state_dict for the layer I created, there is only a single “weight” key when I’m expecting there to be a “weight_v” and “weight_g” as in the original model. Any ideas what could be causing this and how I can get the state_dict to take the “weight_v” and “weight_g” values. I need it to load pretrained weights onto the model I’m creating

As you know, weight-normalized modules have `weight_g` and `weight_v` in their state_dict, unless `nn.utils.remove_weight_norm` has been applied to them, as follows:

```
In [1]: import torch

In [2]: from torch import nn

In [3]: from torch.nn.utils import weight_norm

In [5]: wn_linear = weight_norm(nn.Linear(2, 2, bias=False))

In [6]: wn_linear.state_dict()
Out[6]:
OrderedDict([('weight_g', tensor([[0.3002],
                                  [0.3428]])),
             ('weight_v', tensor([[-0.2730, -0.1248],
                                  [-0.3416, -0.0280]]))])

In [7]: nn.utils.remove_weight_norm(wn_linear).state_dict()
Out[7]:
OrderedDict([('weight', tensor([[-0.2730, -0.1248],
                                [-0.3416, -0.0280]]))])
```

So, I’m not entirely confident, but one trick that comes to mind is to register weight normalization after `load_state_dict`, as follows:
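Here is a minimal sketch of that trick. The layer shapes are just placeholders; the idea is to collapse the pretrained `weight_g`/`weight_v` back into a plain `weight`, load that into a freshly created (not yet normalized) layer, and only then call `weight_norm`, which recomputes `weight_g` and `weight_v` from the loaded weight without changing its value:

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm

# Stand-in for the pretrained, weight-normalized Conv2d
# (layer sizes here are illustrative).
pretrained = weight_norm(nn.Conv2d(3, 8, 3, bias=False))

# Collapse weight_g / weight_v back into a single "weight" key,
# so the state_dict matches a plain, un-normalized layer.
plain_state = remove_weight_norm(pretrained).state_dict()  # only 'weight'

# Freshly created layer without weight_norm -> keys match plain_state.
new_conv = nn.Conv2d(3, 8, 3, bias=False)
new_conv.load_state_dict(plain_state)

# Register weight normalization *after* loading. PyTorch sets
# weight_v = weight and weight_g = ||weight|| (per output channel),
# so the effective weight is unchanged.
new_conv = weight_norm(new_conv)
print(sorted(new_conv.state_dict().keys()))  # ['weight_g', 'weight_v']
```

After this, `new_conv.state_dict()` exposes `weight_g` and `weight_v` again, so it can serve as a drop-in replacement for the original weight-normalized layer.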
