ONNX Runtime performance problem related to exporting

Hi, guys!

I found a weird case with ONNX Runtime performance:
I trained resnext50_32x4d with 3 heads (head = conv2d(…, groups=64)-relu-linear) via pytorch-lightning. The model code is below. I then exported the model to ONNX via torch.onnx.export with opset_version=13 (using the code from the example) and ran it with onnxruntime on CPU. I got approximately 225 ms per image in single-thread mode.
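Per-image figures like this depend on how the timing is taken, and the post does not say how the 225 ms was measured. A minimal sketch of one way to take such a measurement, assuming warm-up calls are discarded and the median of repeated calls is reported; `time_per_call` and its parameters are hypothetical names, not part of onnxruntime:

```python
import statistics
import time


def time_per_call(fn, n_warmup=5, n_runs=50):
    """Return the median wall-clock time of fn() in milliseconds."""
    for _ in range(n_warmup):
        fn()  # warm-up calls are not timed
    samples = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)


# Stand-in workload; with onnxruntime this would be something like
# lambda: session.run(None, {input_name: image})
elapsed_ms = time_per_call(lambda: sum(range(10_000)))
print(f"{elapsed_ms:.3f} ms per call")
```

Using the same helper for both exported models keeps the two numbers comparable.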

But if I manually round the model weights to 16 digits after the decimal point (using the load_model(…, n_digits=16) function) before exporting to ONNX, I get 75 ms per image, a 3x speed-up. I compared the output tensors to 16-digit precision and they matched completely. I also checked that the rounding did not change the model weights.
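To make the rounding step concrete, here is a scalar sketch of the formula used in load_model, round(x * 10^n) / 10^n. It uses plain Python floats (double precision) rather than float32 tensors, so it only illustrates the arithmetic, not exact float32 behaviour; `round_to_digits` is a hypothetical name:

```python
def round_to_digits(x, n_digits=16):
    """Scalar version of the rounding in load_model: round(x * 10^n) / 10^n."""
    scale = 10 ** n_digits
    return round(x * scale) / scale


# Magnitudes below 0.5 * 10**-16 scale to less than 0.5 and round to zero.
tiny = round_to_digits(3e-17)
tinier = round_to_digits(-1e-20)
# A value that is an exact multiple of 10**-16 is returned unchanged.
half = round_to_digits(0.5)
print(tiny, tinier, half)
```

Under this formula, any weight with a magnitude below 5e-17 ends up as exactly zero, so an exact element-wise comparison of the state_dict before and after rounding (rather than a tolerance-based one) would show whether the trained model has any such values.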

I couldn’t reproduce this with a pretrained network, and I cannot share my trained model, sorry.

Could someone explain this? Any suggestions?

Versions:

  • onnxruntime==1.8.0
  • onnx==1.9.0
  • torch==1.8.1
  • torchvision==0.9.1
  • pytorch-lightning==1.3.6

Hardware: Intel (R) Core™ i5-10600K CPU @ 4.10GHz

Single-thread mode CPU:

  • OMP_NUM_THREADS = 1
  • OPENBLAS_NUM_THREADS = 1
  • MKL_NUM_THREADS = 1
  • VECLIB_MAXIMUM_THREADS = 1
  • NUMEXPR_NUM_THREADS = 1
  • (onnxruntime) intra_op_num_threads = 1
  • (onnxruntime) inter_op_num_threads = 1
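
The settings above can be applied roughly as follows. This is a sketch, not necessarily how they were set for the measurements in this post: the environment variables are set before the numeric libraries are imported, and the two onnxruntime options go on a SessionOptions object. `make_single_thread_session` is a hypothetical helper name.

```python
import os

# Thread-count variables from the list above; set them before NumPy, PyTorch
# or onnxruntime are imported, since many libraries read them at load time.
THREAD_ENV_VARS = (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
)
for name in THREAD_ENV_VARS:
    os.environ[name] = "1"


def make_single_thread_session(model_path):
    """Create a CPU onnxruntime session limited to one thread (hypothetical helper)."""
    # Imported here so the environment variables above are already in place.
    import onnxruntime as ort

    options = ort.SessionOptions()
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1
    return ort.InferenceSession(
        model_path, sess_options=options, providers=["CPUExecutionProvider"]
    )
```
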
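
Model code:

from typing import List

import torch
from torchvision import models


# load model from state_dict
def load_model(model, path, n_digits=16):
    state_dict = torch.load(path, map_location="cpu")
    if n_digits:
        scale = 10 ** n_digits
        for key, value in state_dict.items():
            # round float tensors only; skip integer buffers such as num_batches_tracked
            if torch.is_floating_point(value):
                state_dict[key] = (value.float() * scale).round() / scale
    model.load_state_dict(state_dict)
    model.eval()
    return model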


# model
class ConvActLin(torch.nn.Module):
    def __init__(
        self,
        in_channels_2d,
        out_channels_2d,
        num_classes,
        kernel_size_2d=(1, 1),
        activation=torch.nn.ReLU,
        **kwargs
    ):
        super().__init__()
        self.head = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels_2d, out_channels_2d, kernel_size_2d, **kwargs),
            activation(inplace=True),
            torch.nn.Flatten(),
            torch.nn.Linear(out_channels_2d, num_classes)
        )

    def forward(self, x):
        # Backbone features arrive as (N, C); reshape to (N, C, 1, 1) for Conv2d.
        x = x.unsqueeze(-1).unsqueeze(-1)
        return self.head(x)


class NetModule(torch.nn.Module):
    def __init__(self, num_classes: List[int], use_pretrained: bool = True, **kwargs):
        super().__init__()
        self.model = models.resnext50_32x4d(pretrained=use_pretrained)
        in_features = self.model.fc.in_features
        self.model.fc = torch.nn.Identity()
        self.layer_lbl1 = ConvActLin(
            in_features, in_features, num_classes[0], **kwargs)
        self.layer_lbl2 = ConvActLin(
            in_features, in_features, num_classes[1], **kwargs)
        self.layer_lbl3 = ConvActLin(
            in_features, in_features, num_classes[2], **kwargs)

    def forward(self, x):
        x = self.model(x)
        lbl1 = self.layer_lbl1(x)
        lbl2 = self.layer_lbl2(x)
        lbl3 = self.layer_lbl3(x)
        return lbl1, lbl2, lbl3