load_state_dict misses some keys if I convert BatchNorm2d to TorchScript with torch.jit.trace

Hi,
I'm playing with someone else's CRNN project ( GitHub - meijieru/crnn.pytorch: Convolutional recurrent network in pytorch ) and trying to convert the code to TorchScript so I can run the model in LibTorch.
To avoid the "_is_full_backward_hook : NoneType" error in LibTorch, I trace every module individually and export the result.

import torch
import torch.nn as nn


class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        # trace each submodule with a dummy example input
        self.rnn = torch.jit.trace(nn.LSTM(nIn, nHidden, bidirectional=True),
                                   torch.rand(nHidden, nIn, nIn))
        self.embedding = torch.jit.trace(nn.Linear(nHidden * 2, nOut),
                                         torch.rand(nOut, nHidden * 2))

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output


class CRNN(nn.Module):

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            tempLayer = torch.jit.trace(nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]),
                                        torch.rand(1, nIn, 100, 100))
            cnn.add_module('conv{0}'.format(i), tempLayer)

            if batchNormalization:
                # the error happens here!
                tempLayer = torch.jit.trace(nn.BatchNorm2d(nOut),
                                            torch.rand(20, nOut, 35, 45))
                cnn.add_module('batchnorm{0}'.format(i), tempLayer)

            if leakyRelu:
                tempLayer = torch.jit.trace(nn.LeakyReLU(0.2, inplace=True), torch.rand(1, 1, 100, 100))
                cnn.add_module('relu{0}'.format(i), tempLayer)
            else:
                tempLayer = torch.jit.trace(nn.ReLU(True), torch.rand(2))
                cnn.add_module('relu{0}'.format(i), tempLayer)

        convRelu(0)
        tempLayer = torch.jit.trace(nn.MaxPool2d(2, 2), torch.rand(64, 16, 64))
        cnn.add_module('pooling{0}'.format(0), tempLayer)  # 64x16x64
        convRelu(1)
        tempLayer = torch.jit.trace(nn.MaxPool2d(2, 2), torch.rand(128, 8, 32))
        cnn.add_module('pooling{0}'.format(1), tempLayer)  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        tempLayer = torch.jit.trace(nn.MaxPool2d((2, 2), (2, 1), (0, 1)), torch.rand(256, 4, 16))
        cnn.add_module('pooling{0}'.format(2), tempLayer)  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        tempLayer = torch.jit.trace(nn.MaxPool2d((2, 2), (2, 1), (0, 1)), torch.rand(512, 2, 16))
        cnn.add_module('pooling{0}'.format(3), tempLayer)  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        #assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)

        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        output = self.rnn(conv)

        return output

loadedModel = torch.load("/home/gino/crnn.pth")

model = CRNN(32, 1, 37, 256)
model.cpu()

# strict=True: this is the call that raises the missing-key error
model.load_state_dict(loadedModel, True)
rnnJit = torch.jit.script(model)

rnnJit.save("rnnJit.pt")

I want to load the pretrained model this way.

If I change

cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))

to

tempLayer = torch.jit.trace(nn.BatchNorm2d(nOut), torch.rand(20, nOut, 35, 45))
cnn.add_module('batchnorm{0}'.format(i), tempLayer)

then load_state_dict reports the error below. If I ignore it (load with strict=False), the prediction results are wrong.

RuntimeError: Error(s) in loading state_dict for CRNN:
Missing key(s) in state_dict: "cnn.batchnorm2.num_batches_tracked", "cnn.batchnorm4.num_batches_tracked", "cnn.batchnorm6.num_batches_tracked".
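
I can reproduce the mismatch with a single traced BatchNorm2d outside the full model. My guess is that the traced module loses BatchNorm's lenient handling of checkpoints saved before the num_batches_tracked buffer existed, but I'm not sure. A minimal sketch (the hand-built state dict below just drops that key, the way my old checkpoint apparently does):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(512)
traced = torch.jit.trace(nn.BatchNorm2d(512), torch.rand(20, 512, 35, 45))

# simulate an old checkpoint that predates num_batches_tracked
old_sd = {k: v for k, v in bn.state_dict().items()
          if not k.endswith('num_batches_tracked')}

bn.load_state_dict(old_sd)      # plain module: loads without complaint
traced.load_state_dict(old_sd)  # traced module: missing-key RuntimeError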

If the nn.BatchNorm2d is not wrapped in jit.trace, the model's forward works fine in PyTorch, but the "_is_full_backward_hook" problem comes back when the model is exported to a .pt file for C++ LibTorch.
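
I also wonder whether I should load the checkpoint into the plain, untraced CRNN first and only convert to TorchScript afterwards, something like the untested sketch below (it assumes the original modules without per-module tracing, and a dummy 1x1x32x100 input):

model = CRNN(32, 1, 37, 256)  # the original version, no per-module tracing
model.load_state_dict(torch.load("/home/gino/crnn.pth"))
model.eval()  # freeze batch-norm running statistics for inference

# trace the whole model in one step instead of module by module
traced = torch.jit.trace(model, torch.rand(1, 1, 32, 100))
traced.save("rnnJit.pt")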

The only other workaround I can think of is to patch the missing keys into the checkpoint by hand before loading it. What should I do?
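
Here is the rough patch I mean, filling the missing counters with zeros (I believe the counter is unused at inference time, but I'm not certain zero is the right value):

state = torch.load("/home/gino/crnn.pth")

# add the buffers that the old checkpoint predates
for key in ("cnn.batchnorm2.num_batches_tracked",
            "cnn.batchnorm4.num_batches_tracked",
            "cnn.batchnorm6.num_batches_tracked"):
    state.setdefault(key, torch.tensor(0, dtype=torch.long))

model.load_state_dict(state, True)

Is that safe, or is there a proper fix?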