PyTorch fully-connected network ~2x slower than tf.keras equivalent

I noticed that fitting a simple fully-connected network is close to 2x faster in tf.keras than in PyTorch. I tested on an NVIDIA A100 GPU and reproduced the gap in a Colab notebook.

This seems to be the case regardless of the number of input features, the number of hidden layers, and the number of hidden units, and it appears to be more pronounced at larger batch sizes. I also noticed that running the tf.keras model eagerly with run_eagerly=True seems to eliminate its performance advantage.
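For reference, this is how eager execution was forced on the tf.keras side; the model definition here is an illustrative placeholder, not the notebook's exact architecture:

```python
# Hedged sketch: forcing eager execution in tf.keras.
# The layer sizes below are placeholders, not the notebook's values.
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation="relu", input_shape=(32,)),
    tf.keras.layers.Dense(1),
])
# run_eagerly=True disables tf.function graph tracing, so each training
# step runs op-by-op, similar to PyTorch's default eager mode.
model.compile(optimizer="adam", loss="mse", run_eagerly=True)
```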

Tested with PyTorch 1.7.0+cu101 and TensorFlow 2.4.1; please see the Colab notebook below for a simple reproducer.

I'd be interested to know if I'm missing anything, why TensorFlow appears to be faster, and whether anything can be altered in the PyTorch code to match the TensorFlow performance.