Hi all,
After several years of applying deep learning with Keras/TensorFlow, I recently tried to port a rather simple image classification task from TensorFlow/Keras to PyTorch/Lightning. Basically, everything works; however, Torch does not reach the same accuracy as Keras. After about two weeks of comparing and analyzing (mostly based on topics I found here) without resolving the issue, I decided to ask. Below, I explain my setup step by step, trying not to miss anything important while skipping the trivial parts.
In Keras, the images are loaded and processed by a custom class derived from tf.keras.utils.Sequence. In Torch, I created an equivalent torch.utils.data.Dataset class that is passed to a torch.utils.data.DataLoader instance. In both cases, the batches are loaded by the same code (wrapped in the respective class), so data loading and processing are identical. Additionally, I verified that the images are loaded correctly by displaying them after loading. Before being passed to the model, they are scaled to [0, 1] by multiplying by 1/255 in either case.
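A minimal sketch of such a wrapper, assuming a hypothetical shared function `load_fn(path)` that returns an 8-bit image as a NumPy array of shape (H, W, C). Besides the 1/255 scaling, one detail worth checking is the channel order: Keras works channels-last (N, H, W, C) by default, whereas nn.Conv2d expects channels-first (N, C, H, W), so each image needs a permute. If the shared code already returns whole batches (as a Sequence does), passing batch_size=None to the DataLoader disables automatic batching instead.

```python
from typing import Callable, Sequence

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ImageDataset(Dataset):
    """Wraps a shared per-image loading function (hypothetical `load_fn`)."""

    def __init__(self, paths: Sequence[str], labels: Sequence[int],
                 load_fn: Callable[[str], np.ndarray]):
        self.paths = list(paths)
        self.labels = list(labels)
        self.load_fn = load_fn  # same loading code as used by the Keras Sequence

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int):
        img = self.load_fn(self.paths[idx])          # uint8 array, shape (H, W, C)
        x = torch.from_numpy(img).float() / 255.0    # scale to [0, 1]
        x = x.permute(2, 0, 1).contiguous()          # HWC (Keras) -> CHW (PyTorch)
        y = torch.tensor(self.labels[idx], dtype=torch.long)  # class index
        return x, y


def dummy_load(path: str) -> np.ndarray:
    # Hypothetical stand-in for the real image loader: a white 32x32 RGB image.
    return np.full((32, 32, 3), 255, dtype=np.uint8)


loader = DataLoader(ImageDataset(["a.png", "b.png"], [0, 1], dummy_load),
                    batch_size=2, shuffle=True)
xb, yb = next(iter(loader))
print(xb.shape, xb.max().item())  # torch.Size([2, 3, 32, 32]) 1.0
```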
The model used is a simple CNN consisting only of Conv2D, ReLU, MaxPool and fully connected layers. The conversion is relatively straightforward:
tf.keras.layers.Conv2D(channels, kernel_size, activation=tf.nn.relu),
tf.keras.layers.MaxPool2D((2,2)),
translates to
nn.Conv2d(channels_in, channels_out, kernel_size),
nn.ReLU(),
nn.MaxPool2d((2,2))
and similarly, after flattening the output of the convolutional part,
tf.keras.layers.Dense(size, activation=tf.nn.relu)
translates to
nn.Linear(size_in, size_out),
nn.ReLU()
both used within a tf.keras.models.Sequential and an nn.Sequential model, respectively.
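Put together, the converted Torch model might look like the sketch below. The input size (3×64×64), channel counts, hidden size and number of classes are hypothetical, since they are not part of this post. nn.Conv2d's default padding=0 corresponds to the Keras default padding="valid". Keras' Flatten sees channels-last tensors and Torch's channels-first ones, so the order of the flattened features differs; this should not matter when training from scratch, only when copying weights between the frameworks.

```python
import torch
from torch import nn

# Hypothetical sizes; the real ones are not given in the post.
IN_CH, H, W, NUM_CLASSES = 3, 64, 64, 10

model = nn.Sequential(
    nn.Conv2d(IN_CH, 32, kernel_size=3),  # Keras: Conv2D(32, 3, activation=relu), padding "valid"
    nn.ReLU(),
    nn.MaxPool2d((2, 2)),                 # 62x62 -> 31x31
    nn.Conv2d(32, 64, kernel_size=3),     # 31x31 -> 29x29
    nn.ReLU(),
    nn.MaxPool2d((2, 2)),                 # 29x29 -> 14x14 (floor)
    nn.Flatten(),
    nn.Linear(64 * 14 * 14, 128),         # Keras: Dense(128, activation=relu)
    nn.ReLU(),
    nn.Linear(128, NUM_CLASSES),          # raw logits, no softmax in the Torch model
)

out = model(torch.zeros(2, IN_CH, H, W))
print(out.shape)  # torch.Size([2, 10])
```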
Verifying both models with model.summary() and torchinfo.summary() shows the same layer structure, i.e. the same output shapes and an identical number of parameters.
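As a quick arithmetic cross-check of those summaries: a convolution has k·k·c_in weights plus one bias per output channel, and a dense layer has (n_in + 1)·n_out parameters, in both frameworks. The helper functions below are hypothetical and only compare these formulas against Torch's own parameter count.

```python
from torch import nn


def conv2d_params(c_in: int, c_out: int, k: int) -> int:
    # k*k*c_in weights per output channel plus one bias each
    return (k * k * c_in + 1) * c_out


def dense_params(n_in: int, n_out: int) -> int:
    # n_in weights per output unit plus one bias each
    return (n_in + 1) * n_out


def n_params(layer: nn.Module) -> int:
    return sum(p.numel() for p in layer.parameters())


print(conv2d_params(3, 32, 3), n_params(nn.Conv2d(3, 32, 3)))  # 896 896
print(dense_params(128, 10), n_params(nn.Linear(128, 10)))     # 1290 1290
```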
During the conversion, I noticed two significant differences between Keras and Torch that were a bit tricky. First, in Keras the final softmax classification layer is included in the model and the loss is computed on its output, whereas in Torch the loss computation expects raw (un-softmaxed) logits. Second, the layer weights and biases are initialized differently: Keras uses zeros for the biases and Xavier (Glorot) uniform for the weights. The Torch equivalents are
torch.nn.init.zeros_(layer.bias)
torch.nn.init.xavier_uniform_(layer.weight)
applied to all layers that contain parameters (i.e. convolution and linear).
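Both differences can be checked in a few lines. The sketch below uses a tiny hypothetical model and random input: it applies the Keras-like initialization to all Conv2d/Linear layers via model.apply(), and confirms that F.cross_entropy on raw logits with class-index targets equals categorical cross-entropy computed by hand on softmax probabilities with one-hot targets (the Keras-style setup). This is also why no softmax layer should be added to the Torch model: it would then be applied twice. As far as I know, both frameworks include the kernel size in the Xavier fan-in/fan-out of convolutions, so the sampling ranges should match.

```python
import torch
import torch.nn.functional as F
from torch import nn


def init_keras_like(module: nn.Module) -> None:
    # Mirror Keras defaults for Conv2D/Dense: Glorot (Xavier) uniform kernels, zero biases.
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


# Hypothetical tiny model; note there is no softmax layer at the end.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 5))
model.apply(init_keras_like)  # visits every submodule recursively

torch.manual_seed(0)
x = torch.rand(4, 3, 32, 32)
targets = torch.tensor([0, 3, 1, 4])  # class indices, not one-hot

logits = model(x)
# Torch: cross_entropy applies log-softmax to the raw logits internally.
torch_loss = F.cross_entropy(logits, targets)

# Keras-style: softmax in the model, categorical cross-entropy on the probabilities
# with one-hot targets (written out by hand here, no TensorFlow involved).
probs = torch.softmax(logits, dim=1)
one_hot = F.one_hot(targets, num_classes=5).float()
keras_style_loss = -(one_hot * torch.log(probs)).sum(dim=1).mean()

print(torch.allclose(torch_loss, keras_style_loss, atol=1e-6))  # True
```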
In Keras, the model was compiled (model.compile()) with a default Adam optimizer and trained using model.fit().
To do the same in Torch, I implemented a small lightning.LightningModule class. The optimizer is the same: Torch's Adam uses eps=1e-8 by default, so I changed this to 1e-7 as in Keras; everything else is identical. In particular, the learning rate (0.001) and the batch size are the same in both settings. Passing this configuration to a lightning.Trainer() yields practically the same training functionality as Keras.
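Stripped of Lightning, the optimizer setup and one epoch of training can be sketched in plain Torch as below. The function names and the toy data are hypothetical; inside a LightningModule the optimizer would be returned from configure_optimizers() and the loop body would roughly correspond to training_step() plus what the Trainer does. One caveat, hedged: the Keras documentation describes its epsilon as the "epsilon hat" of the Adam paper, so the two eps values may not enter the update in exactly the same place even when both are 1e-7.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def make_keras_like_adam(params):
    # Keras Adam defaults: learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7
    return torch.optim.Adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-7)


def train_one_epoch(model, loader, optimizer, loss_fn):
    """Roughly one epoch of Keras model.fit(): no callbacks, no validation."""
    model.train()
    total_loss, n_seen = 0.0, 0
    for x, y in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)  # mean loss over the batch
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        n_seen += x.size(0)
    return total_loss / n_seen       # sample-weighted mean over the epoch


# Usage on hypothetical synthetic data (a linearly separable toy problem).
torch.manual_seed(0)
X = torch.randn(256, 4)
Y = (X[:, 0] > 0).long()
loader = DataLoader(TensorDataset(X, Y), batch_size=32, shuffle=True)
model = nn.Linear(4, 2)
optimizer = make_keras_like_adam(model.parameters())
losses = [train_one_epoch(model, loader, optimizer, nn.CrossEntropyLoss()) for _ in range(20)]
print(losses[0], losses[-1])
```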
The image data consists of ~100k images, randomly split with stratification into 80% training / 20% validation data using sklearn's train_test_split. In Keras, this yields the following losses and accuracies during the first ten epochs (i.e. the running loss and accuracy of each epoch):
loss: 0.1162 - accuracy: 0.9580
loss: 0.0507 - accuracy: 0.9823
loss: 0.0402 - accuracy: 0.9858
loss: 0.0345 - accuracy: 0.9880
loss: 0.0304 - accuracy: 0.9891
loss: 0.0276 - accuracy: 0.9904
loss: 0.0251 - accuracy: 0.9909
loss: 0.0242 - accuracy: 0.9914
loss: 0.0229 - accuracy: 0.9917
loss: 0.0213 - accuracy: 0.9924
whereas Torch yields (step outputs removed):
[acc=0.934, loss=0.134]
[acc=0.975, loss=0.050]
[acc=0.981, loss=0.0386]
[acc=0.982, loss=0.0332]
[acc=0.984, loss=0.0299]
[acc=0.985, loss=0.0285]
[acc=0.986, loss=0.0262]
[acc=0.987, loss=0.024]
[acc=0.987, loss=0.0236]
[acc=0.988, loss=0.0225]
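When comparing these two logs, it seems worth making sure both report the same quantity. Keras prints a running mean over all batches of the epoch so far, while, as far as I know, values logged with self.log() in a Lightning training_step default to on_step=True, on_epoch=False, so the progress bar may show the latest batch rather than an epoch average (logging with on_epoch=True should give an epoch mean). The hypothetical helper below shows how a sample-weighted running mean differs from the last batch value.

```python
class RunningMean:
    """Sample-weighted running mean, like the values Keras prints during fit()."""

    def __init__(self) -> None:
        self.total, self.count = 0.0, 0

    def update(self, value: float, n: int) -> None:
        self.total += value * n
        self.count += n

    @property
    def value(self) -> float:
        return self.total / self.count


acc = RunningMean()
batch_accs = [0.80, 0.95, 0.99]  # hypothetical per-batch accuracies within one epoch
for a in batch_accs:
    acc.update(a, n=32)
print(round(acc.value, 4), batch_accs[-1])  # 0.9133 0.99
```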
Obviously, the task is solved well in both cases, but Keras yields about 0.5% higher accuracy. The same holds for the validation accuracy after 20-40 epochs: Keras reaches ~99.5%, Torch “only” ~99%. Granted, 99% accuracy is still fine, but 99.5% means half the error rate, which is more than random variation. Most importantly, these results are reproducible across repeated trainings: although randomization effects exist (train/validation split, shuffling of images within each epoch, etc.), the loss and accuracy results are relatively stable. In particular, Keras consistently outperforms Torch by 0.5% validation accuracy (99.5% vs. 99.0%).
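To put the gap into numbers: with ~100k images, the 20% validation split holds roughly 20,000 images (assumed below), so 99.5% vs. 99.0% accuracy means about 100 vs. 200 misclassified images. A rough two-proportion z-test is sketched here; it assumes independent samples of equal size and ignores that the splits differ between runs, so it is only a sanity check, not a rigorous comparison.

```python
import math

n = 20_000                       # assumed validation size: ~20% of ~100k images
err_keras, err_torch = 0.005, 0.010

print(round(err_keras * n), round(err_torch * n))  # 100 200 misclassified images

# Two-proportion z-test with pooled error rate.
p = (err_keras + err_torch) / 2
se = math.sqrt(p * (1 - p) * (2 / n))
z = (err_torch - err_keras) / se
print(round(z, 1))  # 5.8
```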
Furthermore, even using Torch’s default initialization (i.e. removing the explicit Xavier uniform and zeros initialization of weights and biases) does not change anything. To me, something is significantly different between training with Keras and with Torch. I read that Keras applies an internal learning rate decay within each epoch (cf. Keras learning rate schedules and decay - PyImageSearch), not to be confused with a learning rate scheduler stepped after each epoch, which can easily be implemented in both frameworks. I suspect that this is one of the most significant differences. Does Torch perform the same learning rate updates within each epoch? If not, how can this ideally be implemented in Torch? Or are there any other significant differences one should be aware of?
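If I read the Keras source correctly, the decay in the legacy optimizers is the time-based rule lr / (1 + decay · iterations), applied per batch and controlled by the `decay` argument, which defaults to 0; a default Adam would then not decay at all. Should a non-zero decay be wanted, a per-step equivalent in Torch could be sketched with LambdaLR as below. The decay value is hypothetical, and the scheduler has to be stepped after every optimizer step rather than once per epoch.

```python
import torch
from torch import nn

model = nn.Linear(10, 2)  # placeholder for the CNN
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-7)

decay = 1e-3  # hypothetical value
# Time-based decay per optimizer step k: lr_k = lr_0 / (1 + decay * k)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda k: 1.0 / (1.0 + decay * k)
)

for step in range(1000):
    # ... forward pass, loss.backward() ...
    optimizer.step()
    scheduler.step()  # once per batch, not once per epoch

print(scheduler.get_last_lr())  # [0.0005]
```

In Lightning, I believe the same per-batch stepping is requested by returning {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}} from configure_optimizers().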
I’d appreciate any comments. Thanks in advance!