My data is a financial time series whose stationarity properties in the train set will always differ slightly from those of the test set. I'm using an LSTM model with batch normalization on the final linear layers (LSTM → dropout → batch norm → linear → ReLU → …); a rough sketch of that layout follows.
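For concreteness, here is a minimal sketch of that layout, assuming a single-output regression head (the layer sizes and names are my assumptions, not the actual model):

```python
import torch
import torch.nn as nn

class LSTMRegressor(nn.Module):
    """Hypothetical model matching the described layout:
    LSTM -> dropout -> batch norm -> linear -> ReLU -> linear."""

    def __init__(self, n_features: int, hidden: int = 64, dropout: float = 0.2):
        super().__init__()
        self.lstm = nn.LSTM(n_features, hidden, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.bn = nn.BatchNorm1d(hidden)   # normalizes across the batch dimension
        self.fc1 = nn.Linear(hidden, hidden)
        self.fc2 = nn.Linear(hidden, 1)

    def forward(self, x):                  # x: (batch, seq_len, n_features)
        out, _ = self.lstm(x)
        last = out[:, -1, :]               # last time step: (batch, hidden)
        last = self.dropout(last)
        last = self.bn(last)               # stats computed over the batch
        return self.fc2(torch.relu(self.fc1(last)))
```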
model.eval(), and realized that this was due to the batch normalization layers' running means and variances not adjusting to the test data. I understand that updating the running stats during evaluation would be a data leak, especially with sequential data, where the batch statistics would effectively be computed over future datapoints.

I worked around this by using batches whose items are ordered by time and evaluating only the last item, so the running stats are computed from previous data only. I set the batch normalization layers' momentum to 1 (so that the running stats are calculated from the current batch alone) and evaluate only the last item of the batch. Of course, evaluation takes significantly longer this way, since I run a forward pass on a whole batch while scoring only a single item, sliding the window across the whole dataset. Due to the very long evaluation time, I update the running stats only every 16 data points, in the following way (sketched in code after the list):
1. Set the batch normalization layers to train() and perform a forward pass on a batch of 512, evaluating only the last item.
2. Set the batch normalization layers to eval() and evaluate a smaller batch (16), then repeat step 1.
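Here is a minimal sketch of that loop as I understand it, assuming the sliding windows are precomputed into one time-ordered tensor where `windows[t]` ends at time step t (the tensor layout, names, and chunk bookkeeping are my assumptions):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def rolling_eval(model: nn.Module, windows: torch.Tensor,
                 stats_batch: int = 512, eval_chunk: int = 16) -> torch.Tensor:
    """Hypothetical rolling evaluation: refresh BN running stats from a
    512-item time-ordered batch, then score the next 16 points in eval mode."""
    model.eval()  # dropout etc. stay in eval mode throughout

    bn_layers = [m for m in model.modules() if isinstance(m, nn.BatchNorm1d)]
    for bn in bn_layers:
        bn.momentum = 1.0  # running stats = stats of the most recent batch only

    preds = []
    t = stats_batch - 1  # first point with a full 512-window history
    while t < len(windows):
        # Step 1: BN layers in train mode compute (and, with momentum=1,
        # store) the batch statistics of the 512 windows ending at t;
        # keep only the prediction for the last (current) item.
        for bn in bn_layers:
            bn.train()
        out = model(windows[t - stats_batch + 1 : t + 1])
        preds.append(out[-1])

        # Step 2: score the next `eval_chunk` points with the frozen stats.
        for bn in bn_layers:
            bn.eval()
        chunk = windows[t + 1 : t + 1 + eval_chunk]
        if len(chunk) > 0:
            preds.extend(model(chunk))
        t += 1 + eval_chunk  # then repeat step 1

    return torch.stack(preds)
```

Note that only the batch norm layers are toggled, so dropout remains in eval mode the whole time, matching the two steps above.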
This greatly improves the test results without taking too long to compute. Did I take the right approach?