UserWarning: Exporting a model to ONNX with a batch_size other than 1

I hit this warning when converting a CRNN to an ONNX model; my code is as follows:

from torch import nn, onnx
import torch

class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output

class CRNN(nn.Module):

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False, lstmFlag=True):
        """
        是否加入lstm特征层
        """
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]
        self.lstmFlag = lstmFlag

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        if self.lstmFlag:
            self.rnn = nn.Sequential(
                BidirectionalLSTM(512, nh, nh),
                BidirectionalLSTM(nh, nh, nclass))
        else:
            self.linear = nn.Linear(nh * 2, nclass)

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        conv = conv.squeeze(2)        # h is 1 after the conv stack, so drop the height dim
        conv = conv.permute(2, 0, 1)  # [w, b, c]: sequence-first for the LSTM
        output = self.rnn(conv)
        return output


model = CRNN(32, 1, 5530, 256, n_rnn=2, leakyRelu=False, lstmFlag=True)
x = torch.rand(1, 1, 32, 100)

onnx.export(
    model, x, "crnn.onnx",
    export_params=True,
    opset_version=11,
    # verbose=True,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {1: "batch_size"}},
)

The full text of the warning is:

/home/dai/py36env/lib/python3.6/site-packages/torch/onnx/symbolic_opset9.py:1377: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable lenght with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model.
  "or define the initial states (h0/c0) as inputs of the model. ")

How can I avoid this warning, or how do I define the initial states (h0/c0)?
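
For reference, here is a minimal sketch of what "define the initial states (h0/c0) as inputs of the model" could look like, reduced to a single bidirectional LSTM. The module name and shapes are illustrative, and threading h0/c0 through the nn.Sequential of BidirectionalLSTM modules in the CRNN above would still need a small refactor:

import torch
from torch import nn, onnx

class LSTMWithState(nn.Module):
    # Sketch only: h0/c0 become explicit graph inputs, so their batch
    # dimension can be declared dynamic instead of being constant-folded.
    def __init__(self, nIn, nHidden):
        super(LSTMWithState, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)

    def forward(self, x, h0, c0):
        out, _ = self.rnn(x, (h0, c0))
        return out

lstm = LSTMWithState(512, 256)
x = torch.rand(26, 1, 512)    # [T, b, nIn], sequence-first
h0 = torch.zeros(2, 1, 256)   # [num_layers * num_directions, b, nHidden]
c0 = torch.zeros(2, 1, 256)

onnx.export(
    lstm, (x, h0, c0), "lstm_state.onnx", opset_version=11,
    input_names=["x", "h0", "c0"], output_names=["y"],
    dynamic_axes={"x": {1: "batch_size"}, "h0": {1: "batch_size"},
                  "c0": {1: "batch_size"}, "y": {1: "batch_size"}},
)

With the states as real inputs, their batch dimension is part of the graph signature rather than a baked-in constant, which is exactly the alternative the warning text offers.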

Hi, have you found a solution to this problem?

Same problem over here. Any leads on solving this?

Same problem. Has anyone solved it? I converted the CRNN PyTorch model to ONNX and then to an OpenVINO model, but the inference output shape in OpenVINO is wrong.
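
One thing that may be worth checking before the OpenVINO step: `recurrent.view(T * b, h)` in BidirectionalLSTM takes concrete sizes from `.size()`, and ONNX tracing can bake those into fixed Reshape constants, which would pin the batch size in the exported graph regardless of dynamic_axes. The declared output dims can be inspected with the standard onnx API (nothing here is model-specific):

import onnx

model_proto = onnx.load("crnn.onnx")
for out in model_proto.graph.output:
    dims = [d.dim_param or d.dim_value for d in out.type.tensor_type.shape.dim]
    print(out.name, dims)  # the batch dim should print as 'batch_size', not a hard-coded 1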

Any update on this? I am also facing the same warning while converting a PyTorch-based model to ONNX with a dynamic batch size.

Besides, when I check the ONNX model, I get a segmentation fault (core dumped):

    import onnx
    import onnxruntime

    onnx_model = onnx.load("crnn.onnx")
    onnx.checker.check_model(onnx_model)  # check_model also accepts the file path directly

    ort_session = onnxruntime.InferenceSession("crnn.onnx")
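
If the checker passes on the loaded model, a batch-size-2 run is a quick end-to-end test of the dynamic axis; the expected shape below is my own calculation for this CRNN, which maps a 100-pixel-wide input to T = 26 time steps:

    import numpy as np

    x = np.random.rand(2, 1, 32, 100).astype(np.float32)  # batch of 2 to exercise the dynamic axis
    outputs = ort_session.run(["output"], {"input": x})
    print(outputs[0].shape)  # expect (26, 2, 5530): [T, batch, nclass]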

Have you solved this issue?