How to let BatchNorm2d process three-dimensional input at inference time

import torch
import torch.nn as nn

class CausalTransConv(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, output_padding):
        super(CausalTransConv, self).__init__()
        self.trans_conv = nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel_size,
                                             stride=stride, padding=padding, output_padding=output_padding)
        self.norm = nn.BatchNorm2d(out_ch)
        self.activation = nn.PReLU()

    def forward(self, x):
        # Keep only the first T time frames so the transposed
        # convolution stays causal (no look-ahead).
        T = x.size(-1)
        out = self.trans_conv(x)[..., :T]
        out = self.norm(out)
        out = self.activation(out)
        return out

A simple model is shown above. In the training phase I use 4-dimensional input of shape [B, N, H, W], and at inference time PyTorch also accepts 4-dimensional input. However, the on-device inference framework I deploy to only accepts 3-dimensional input, i.e. without the batch dimension. The ConvTranspose2d and PReLU above already handle 3-dimensional input, meaning [1, N, H, W] and [N, H, W] produce numerically equal results, but BatchNorm2d only accepts [1, N, H, W]. How should I change the model so that jit.trace can take an [N, H, W] input?
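One possible direction (a sketch, not a definitive answer): in eval mode BatchNorm2d is just a fixed per-channel affine transform, y = (x - running_mean) / sqrt(running_var + eps) * weight + bias. That transform can be folded into a scale and shift of shape (C, 1, 1), which broadcasts correctly against both [B, C, H, W] and [C, H, W]. The `FrozenBatchNorm2d` name below is my own; it is not a PyTorch API.

```python
import torch
import torch.nn as nn

class FrozenBatchNorm2d(nn.Module):
    """Per-channel affine equivalent to an eval-mode BatchNorm2d.

    scale/shift have shape (C, 1, 1), so they broadcast against the
    last three dims of either a [B, C, H, W] or a [C, H, W] input.
    """
    def __init__(self, bn: nn.BatchNorm2d):
        super().__init__()
        with torch.no_grad():
            # Fold running stats and affine parameters into one scale/shift.
            scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
            shift = bn.bias - bn.running_mean * scale
        self.register_buffer("scale", scale.reshape(-1, 1, 1))
        self.register_buffer("shift", shift.reshape(-1, 1, 1))

    def forward(self, x):
        return x * self.scale + self.shift
```

After training, you could swap the norm before tracing, e.g. `model.norm = FrozenBatchNorm2d(model.norm)`, and then call `torch.jit.trace` with an [N, H, W] example input; the elementwise multiply and add trace cleanly for both shapes.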