What is updated *per-epoch* (and not *per-batch*)?

I changed my script to use PyTorch Lightning and now I do not get the same results/performance.

Before: I had a hand-written torch script with classic nested training loops (over epochs, and over batches within each epoch).

After: I literally copied and pasted the pieces into the appropriate places to fit PL’s framework.

The catch: I used to define an epoch as 10 iterations over my training dataset (by replicating the instances in the dataloader 10x). Now I define an epoch as a single iteration over the training dataset (no replication), but I multiplied the number of epochs (50) by 10 (so 500), so the total number of batches seen during training is the same.

One solved problem: my learning rate decays by a ratio 0 < r < 1 (so LR_{k+1} = LR_{k} * r, where k is the epoch index), so with 10x more epochs the final learning rate was much smaller than before. I replaced it with an equivalent decay rate (r' = r ^ (50/500)) so that the learning rate at the beginning and at the end of training are the same. Performance improved somewhat, but I’m still far from the previous result.
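For reference, the decay rescaling described above can be checked numerically. This is only a sketch: `base_lr` and `r` are hypothetical values (the post does not give them), and the helper names are made up for illustration. It confirms that the start and end learning rates match, and also shows that the two schedules are not identical in between: the old one is a staircase held constant for 10 passes over the data, while the new one decays after every pass.

```python
# Hypothetical values: the post gives neither the base LR nor the ratio r.
base_lr = 0.1
r = 0.9
old_epochs, new_epochs = 50, 500

# Rescaled per-epoch ratio from the post: r' = r ^ (50/500)
r_prime = r ** (old_epochs / new_epochs)


def lr_after(num_decays, base, ratio):
    """Learning rate after applying the multiplicative decay num_decays times."""
    return base * ratio ** num_decays


# End points: both schedules finish at (numerically) the same learning rate.
old_final = lr_after(old_epochs, base_lr, r)
new_final = lr_after(new_epochs, base_lr, r_prime)


def old_lr_at_pass(p):
    # Old setup: one "epoch" = 10 passes over the data, LR held constant inside it.
    return lr_after(p // 10, base_lr, r)


def new_lr_at_pass(p):
    # New setup: one epoch = one pass, LR decays after every pass.
    return lr_after(p, base_lr, r_prime)


# Largest relative difference between the two schedules over all 500 passes.
max_rel_gap = max(
    abs(old_lr_at_pass(p) - new_lr_at_pass(p)) / old_lr_at_pass(p)
    for p in range(new_epochs)
)

print(f"r' = {r_prime:.6f}")
print(f"final LR old = {old_final:.6e}, new = {new_final:.6e}")
print(f"max relative gap between schedules = {max_rel_gap:.4f}")
```

Within each block of 10 passes the new schedule is at or below the old one, so matching the end points does not by itself reproduce the old learning rate at every batch.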

The question: are there other parameters or hyperparameters that are updated per-epoch (and not per-batch)?


PS1: I’m assuming that the problem comes from some per-epoch update because I reviewed the code and everything should be the same except for the number of epochs, while the total number of batches is unchanged.

PS2: a few things in my training recipe: SGD optimizer with Nesterov momentum, weight decay regularization, and a model with batchnorm, 2d convolutions, and max pooling.


PyTorch won’t update anything in your model behind your back per-iteration or per-epoch, so I would recommend looking for explicit per-epoch updates such as learning rate decay, changes in data augmentation, etc.
Also, you would usually run a validation loop after each training epoch. Make sure to call model.eval() in the validation loop to avoid updating the running stats of batchnorm layers etc.
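
To illustrate the point about batchnorm running stats, here is a minimal sketch in plain PyTorch (not Lightning). The layer size and input tensor are arbitrary values chosen for the example. It shows that a forward pass in train mode updates `running_mean` even under `torch.no_grad()`, while the same pass after `.eval()` leaves it untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary example layer and input (not taken from the original script).
bn = nn.BatchNorm2d(3)
x = torch.randn(8, 3, 4, 4) * 2.0 + 5.0

# Train mode: the forward pass updates the running statistics,
# even without gradients or an optimizer step.
bn.train()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
changed_in_train = not torch.equal(before, bn.running_mean)

# Eval mode: the running statistics are used but not updated.
bn.eval()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
changed_in_eval = not torch.equal(before, bn.running_mean)

print("running_mean changed in train mode:", changed_in_train)
print("running_mean changed in eval mode:", changed_in_eval)
```

So a validation loop that runs while the model is still in train mode would feed validation batches into the running stats once per validation run, which is one way a per-epoch effect can appear.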

PyTorch won’t update anything in your model behind your back per-iteration or per-epoch, so I would recommend looking for explicit per-epoch updates such as learning rate decay, changes in data augmentation, etc.

Got it.

Make sure to call model.eval() in the validation loop to avoid updating the running stats of batchnorm layers etc.

I’m using pytorch-lightning, which is supposed to manage that under the hood.

Thanks!