Batch normalization when test stationarity slightly differs

My data is a financial time series where the stationarity properties of the train set will always differ slightly from those of the test set. I’m using an LSTM model with batch normalization on the final linear layers (LSTM → dropout → batch norm → linear → relu …). I noticed that my test results were much better when I didn’t call model.eval(), and I realized this is because the batch normalization layers’ running means and variances don’t adjust to the test data. I understand that updating the running stats during evaluation would be a data leak, especially with sequential data, since the batch statistics would effectively be computed on future data points.

I addressed this by building batches whose items are ordered by time and evaluating only the last item, so the running stats are computed from previous data only. I set the batch normalization layers’ momentum to 1 (so the means and variances are computed from the current batch only) and evaluate only the last item of the batch. Of course, evaluation takes significantly longer this way, since I run a forward pass over a whole batch while keeping only a single prediction and slide this window across the whole dataset. To keep the evaluation time manageable, I update the running stats only every 16 data points, in the following way:

  1. Set the batch normalization layers to train() and perform a forward pass on a large batch (512 items), keeping only the prediction for the last item.
  2. Set the batch normalization layers to eval() and evaluate a smaller batch (16 items), then repeat step 1 (a rough sketch of this loop is shown below).
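Here is a minimal sketch of that evaluation loop, assuming a `model` whose head contains nn.BatchNorm1d layers, a chronologically ordered tensor `data` of already-windowed inputs, and placeholder batch sizes; the LSTM sequence shapes are omitted for brevity:

```python
import torch
import torch.nn as nn

LARGE_BATCH = 512   # window used to refresh the batch-norm statistics
SMALL_BATCH = 16    # cycle length: items predicted between refreshes

def set_batchnorm_mode(model, mode):
    """Switch only the batch-norm layers between train() and eval()."""
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.train(mode)

@torch.no_grad()
def evaluate(model, data):
    model.eval()          # dropout etc. stay in eval mode throughout
    predictions = []
    t = LARGE_BATCH - 1   # index of the first item to predict
    while t < len(data):
        # Step 1: refresh the running stats from the trailing 512-item
        # window (with momentum=1 the stats come from this batch only)
        # and keep only the prediction for the most recent item t.
        set_batchnorm_mode(model, True)
        window = data[t - LARGE_BATCH + 1:t + 1]
        predictions.append(model(window)[-1])

        # Step 2: with the batch-norm layers frozen, predict the next
        # few items using the stats computed in step 1.
        set_batchnorm_mode(model, False)
        small = data[t + 1:t + SMALL_BATCH]
        if len(small) > 0:
            predictions.extend(model(small))
        t += SMALL_BATCH
    return predictions
```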

This greatly improves the test results without making evaluation take too long. Did I take the right approach?

I think your understanding is correct, and it often comes down to the question of how this model would be deployed in a real-world use case. I.e. I assume you won’t have future data points when the model is deployed, so you also shouldn’t use them during validation.

That sounds reasonable.

Alternatively, you could also use track_running_stats=False, which would then always compute the mean and std from the input activations.
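For illustration, a minimal sketch of such a head, loosely matching the layers described in the question (the layer sizes are placeholders):

```python
import torch.nn as nn

head = nn.Sequential(
    nn.Dropout(p=0.5),
    # With track_running_stats=False the layer keeps no running mean/var
    # and always normalizes with the statistics of the current input batch,
    # in both train() and eval() mode.
    nn.BatchNorm1d(128, track_running_stats=False),
    nn.Linear(128, 1),
    nn.ReLU(),
)
```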

I think it boils down again to the test use case. If you are not leaking any future data (or any other data you wouldn’t have), your approach could be fine.


Thanks for the response; it also made me realize that the issue isn’t really stationarity, but rather the fact that my production setting requires single, sequential, real-time predictions.

Even though this approach improved my results, I’ve since realized that the better option for my use case is group normalization (GroupNorm) instead of batch normalization altogether. It gives better results and requires no gimmicks.
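For reference, a minimal sketch of the swap in the model head, assuming a structure like the one above (the feature and group counts are placeholders). GroupNorm normalizes each sample across groups of channels, so it keeps no running statistics and behaves identically in train() and eval():

```python
import torch.nn as nn

head = nn.Sequential(
    nn.Dropout(p=0.5),
    # Per-sample normalization over 8 groups of 16 channels each;
    # no batch statistics and no train/eval discrepancy.
    nn.GroupNorm(num_groups=8, num_channels=128),
    nn.Linear(128, 1),
    nn.ReLU(),
)
```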