Can't save/load model with state_dict

The following code snippet works:

trained_network = ...
torch.save(trained_network, 'final-model.pt')
new_network = torch.load('final-model.pt', map_location='cpu')

evaluation(new_network)  # new_network performs well

But using state_dict to save/load does not work:

trained_network = ...
torch.save(trained_network.state_dict(), 'final-model.pt')
new_network = init_model(...)
new_network.load_state_dict(torch.load('final-model.pt'))

evaluation(new_network)  # new_network performs badly, as if it had never been trained

I would like to use the latter approach since it is the recommended way. How can I fix it?

I have figured out the problem. First, I compared the parameter names and values using the following code:

trained_network = Net() # network with trained parameters
my_network = Net() # network with default initialization

my_network.load_state_dict(trained_network.state_dict())

for ((k1, v1), (k2, v2)) in zip(my_network.state_dict().items(), trained_network.state_dict().items()):
    assert k1 == k2, "parameter names do not match"
    if not torch.equal(v1, v2):
        print("parameter values do not match:", k1)

I then found that the parameter values differ in the following module:

import torch
from torch import nn, Tensor

class ActNorm(nn.Module):

    def __init__(self, channel: int):
        super(ActNorm, self).__init__()

        self.logs = nn.Parameter(torch.zeros((1, channel, 1)))
        self.bias = nn.Parameter(torch.zeros((1, channel, 1)))
        self.eps = 1e-6
        self.is_inited = False  # plain attribute, not a registered buffer

    def forward(self, x: Tensor):
        if not self.is_inited:
            self.__initialize(x)

        z = x * torch.exp(self.logs) + self.bias
        logdet = x.shape[2] * torch.sum(self.logs)
        return z, logdet

    def __initialize(self, x: Tensor):
        with torch.no_grad():
            bias = -torch.mean(x.detach(), dim=[0, 2], keepdim=True)
            logs = -torch.log(torch.std(x.detach(), dim=[0, 2], keepdim=True) + self.eps)
            self.bias.data.copy_(bias.data)
            self.logs.data.copy_(logs.data)
            self.is_inited = True

It turns out that is_inited is a plain Python attribute, not a parameter or buffer, so it is not included in state_dict(). When I use torch.save(trained_network, ...), the whole module object is pickled, including is_inited = True, so the loaded model skips re-initialization. But after new_network.load_state_dict(...), is_inited is still False (from __init__), so the first forward pass calls __initialize again and overwrites the loaded parameters with statistics of the new input, which causes the bad performance.
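
One way to fix this is to register the flag as a buffer instead of a plain attribute, so it becomes part of the state_dict and is restored by load_state_dict(). A minimal sketch of the idea (the initialization logic is unchanged, only the flag handling differs):

import torch
from torch import nn, Tensor

class ActNorm(nn.Module):

    def __init__(self, channel: int):
        super(ActNorm, self).__init__()

        self.logs = nn.Parameter(torch.zeros((1, channel, 1)))
        self.bias = nn.Parameter(torch.zeros((1, channel, 1)))
        self.eps = 1e-6
        # A buffer is saved in state_dict() and restored by
        # load_state_dict(), unlike a plain Python attribute.
        self.register_buffer('is_inited', torch.tensor(False))

    def forward(self, x: Tensor):
        if not self.is_inited:
            self.__initialize(x)

        z = x * torch.exp(self.logs) + self.bias
        logdet = x.shape[2] * torch.sum(self.logs)
        return z, logdet

    def __initialize(self, x: Tensor):
        with torch.no_grad():
            self.bias.copy_(-torch.mean(x, dim=[0, 2], keepdim=True))
            self.logs.copy_(-torch.log(torch.std(x, dim=[0, 2], keepdim=True) + self.eps))
            self.is_inited.fill_(True)  # now persists through save/load

A quick check that the flag is now part of the state dict:

net = ActNorm(4)
print(list(net.state_dict().keys()))
# ['logs', 'bias', 'is_inited']

Note that checkpoints saved with the old version of the module will not contain the 'is_inited' key, so loading them into the new module requires load_state_dict(..., strict=False) (and setting the flag by hand).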