Inference speed of BertForSequenceClassification

Not sure if this is a better question for the Transformers group, but after moving from PyTorch 1.3.0 to PyTorch 1.9.0 (transformers 2.3.0 to 4.7.0, simpletransformers 0.15.7 to 0.61.6) and profiling my BERT inference model, I found that the (Python-profiled) rate-limiting step was a different PyTorch call in each version:

in 1.3.0, it was `{method 'matmul' of 'torch._C._TensorBase' objects}`;
in 1.9.0, it was `{built-in method torch._C._nn.linear}`, and it was taking 4x as long.
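For reference, the kind of Python-level profiling described above can be reproduced with the standard-library `cProfile`/`pstats` modules. This is a minimal sketch: `inference` is a hypothetical placeholder workload standing in for the real model call (e.g. `model.predict(...)` in simpletransformers), since the actual model setup isn't shown in the question.

```python
import cProfile
import io
import pstats

def inference(n=100_000):
    # Placeholder workload; swap in the real BERT forward pass
    # (e.g. model.predict(to_predict)) when profiling for real.
    return sum(i * 0.5 for i in range(n))

def profile_top(func, limit=5):
    """Run func under cProfile and return the top rows sorted by cumulative time."""
    prof = cProfile.Profile()
    prof.enable()
    func()
    prof.disable()
    buf = io.StringIO()
    pstats.Stats(prof, stream=buf).sort_stats("cumulative").print_stats(limit)
    return buf.getvalue()

report = profile_top(inference)
print(report)
```

With a real model, entries like `{built-in method torch._C._nn.linear}` or `{method 'matmul' of 'torch._C._TensorBase' objects}` would appear in this output, which is where the 4x difference between versions showed up.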