Idiomatic way to standardize data by blocks (days)

I have a dataset of samples recorded on different days: about 100 days, with a variable number of samples (~10K to 100K) per day.

Because the data is nonstationary, it makes sense for us to standardize it per day. The per-day standardized data is used for training. For inference, we store the mean of all daily means and the mean of all daily stds, and standardize inference inputs with those aggregates. This also makes sense in our domain: there is no consistent trend in the data, but the daily mean and std fluctuate.
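For reference, the scheme above can be written out as follows, with $x_{d,j}$ the $j$-th sample on day $d$, $n_d$ the number of samples on day $d$ and $D$ the number of days. The $1/n_d$ standard-deviation estimator is an assumption here; `torch.std` uses $1/(n_d - 1)$ by default.

```latex
\mu_d = \frac{1}{n_d}\sum_{j=1}^{n_d} x_{d,j}, \qquad
\sigma_d = \sqrt{\frac{1}{n_d}\sum_{j=1}^{n_d}\left(x_{d,j} - \mu_d\right)^2}, \qquad
\text{training: } \tilde{x}_{d,j} = \frac{x_{d,j} - \mu_d}{\sigma_d}

\bar{\mu} = \frac{1}{D}\sum_{d=1}^{D}\mu_d, \qquad
\bar{\sigma} = \frac{1}{D}\sum_{d=1}^{D}\sigma_d, \qquad
\text{inference: } \tilde{x} = \frac{x - \bar{\mu}}{\bar{\sigma}}
```

Note that $\bar{\sigma}$ is the average of the daily standard deviations, which in general differs from the standard deviation of all samples pooled together.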

The current pipeline is the following:

Training:

  • Load the data by day
  • For day i, subtract the mean and divide by the std of all samples recorded on that day
  • Store the mean of the daily means and the mean of the daily stds (for inference)
  • Concatenate the standardized samples from all days into a single tensor
  • Create a torch.utils.data.Dataset and a torch.utils.data.DataLoader that shuffles the data and splits it into batches
  • Train the model
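The training steps above can be sketched roughly like this. It is a minimal sketch, not our actual code: `load_days` is a hypothetical stand-in that generates synthetic per-day tensors, the targets are dummy zeros, and `std` is `torch.std`'s default (unbiased) estimator.

```python
from typing import List, Tuple

import torch
from torch.utils.data import DataLoader, TensorDataset


def load_days(num_days: int = 5, num_features: int = 3) -> List[torch.Tensor]:
    # Hypothetical loader: one (n_d, num_features) tensor per day, each day
    # with its own offset and scale to mimic the daily fluctuations.
    g = torch.Generator().manual_seed(0)
    days = []
    for d in range(num_days):
        n = int(torch.randint(50, 200, (1,), generator=g))
        days.append(torch.randn(n, num_features, generator=g) * (1 + d) + 10 * d)
    return days


def standardize_by_day(
    days: List[torch.Tensor], eps: float = 1e-8
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    means, stds, standardized = [], [], []
    for x in days:
        mu = x.mean(dim=0)    # per-feature mean of this day
        sigma = x.std(dim=0)  # per-feature std of this day
        standardized.append((x - mu) / (sigma + eps))
        means.append(mu)
        stds.append(sigma)
    # Aggregates kept for inference: mean of daily means, mean of daily stds.
    global_mean = torch.stack(means).mean(dim=0)
    global_std = torch.stack(stds).mean(dim=0)
    return torch.cat(standardized), global_mean, global_std


days = load_days()
x_train, global_mean, global_std = standardize_by_day(days)
y_train = torch.zeros(len(x_train))  # hypothetical targets
loader = DataLoader(TensorDataset(x_train, y_train), batch_size=32, shuffle=True)
```

The per-day statistics are used once and then discarded; only `global_mean` and `global_std` survive into inference.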

Online inference:

  • For every input sample, subtract the stored mean and divide by the stored std
  • Run the model to get a prediction
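For comparison, the inference path as described lives outside the model. A minimal sketch, where the stored values and the `nn.Linear` model are hypothetical placeholders:

```python
import torch
import torch.nn as nn

# Hypothetical values standing in for the aggregates stored during training.
global_mean = torch.tensor([1.0, 2.0, 3.0])
global_std = torch.tensor([0.5, 1.0, 2.0])

model = nn.Linear(3, 1)  # hypothetical stand-in for the trained model
model.eval()


def standardize(raw: torch.Tensor) -> torch.Tensor:
    # Same aggregate mean/std for every input, regardless of its day.
    return (raw - global_mean) / global_std


def predict(raw: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        return model(standardize(raw))
```

Because `standardize` runs outside `model`, scripting `model` alone would not include this step; that is the second code path.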

This works for us, but it is not the most elegant solution: there are now two different code paths for getting data into the model (training and inference). Training data is pre-processed in bulk, while inference samples arrive raw and are standardized one at a time.
Ideally we would like to do the per-day standardization in the torch.utils.data.Dataset or torch.utils.data.DataLoader based on a train/inference flag (in training it would estimate and store the mean/std from the data, in inference it would use these stored values), or through some other idiomatic alternative (maybe an nn.Module called inside forward, or something like torchvision.transforms?). We compile our models with TorchScript for online inference outside of a Python environment, so it would be great to have the standardization step compiled as well rather than kept outside the pipeline. I hope there is an idiomatic solution that avoids having different code paths for training and inference.
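One possible shape for the nn.Module idea, sketched under assumptions: a module that holds the inference aggregates as buffers (so they are saved in the `state_dict` and compiled by `torch.jit.script`) and optionally accepts per-sample daily mean/std during training. `Standardize` and `StandardizedModel` are hypothetical names, not an existing PyTorch API, and the inner `nn.Linear` is a placeholder.

```python
from typing import Optional

import torch
import torch.nn as nn


class Standardize(nn.Module):
    """Hypothetical: uses per-sample stats if given, else the stored aggregates."""

    def __init__(self, num_features: int, eps: float = 1e-8):
        super().__init__()
        # Buffers: saved with the model and compiled by TorchScript, not trained.
        self.register_buffer("mean", torch.zeros(num_features))
        self.register_buffer("std", torch.ones(num_features))
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        mean: Optional[torch.Tensor] = None,
        std: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if mean is not None and std is not None:
            # Training: each sample carries its own day's statistics.
            return (x - mean) / (std + self.eps)
        # Inference: fall back to the stored aggregates.
        return (x - self.mean) / (self.std + self.eps)


class StandardizedModel(nn.Module):
    def __init__(self, num_features: int):
        super().__init__()
        self.norm = Standardize(num_features)
        self.net = nn.Linear(num_features, 1)  # placeholder for the real model

    def forward(
        self,
        x: torch.Tensor,
        mean: Optional[torch.Tensor] = None,
        std: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.net(self.norm(x, mean, std))


model = StandardizedModel(num_features=3)
# After training, copy the aggregates into the buffers (hypothetical values).
model.norm.mean.copy_(torch.tensor([1.0, 2.0, 3.0]))
model.norm.std.copy_(torch.tensor([0.5, 1.0, 2.0]))
scripted = torch.jit.script(model)  # standardization is compiled with the model
```

During training one would call `model(x, mean, std)` with each sample's daily statistics; at inference the scripted module is called as `scripted(x)` and uses the stored aggregates, so both paths go through the same `forward`.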

It is not immediately obvious how to do this in torch.utils.data.DataLoader (or any step after it), since the loader splits the data into random batches: at that point there is no way to get the daily mean, or even the day a sample belongs to, and there doesn’t seem to be any existing solution for by-group (by-day) transforms such as standardization.
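One way around the shuffling problem, sketched as an assumption rather than a known best practice: compute the daily statistics up front in the Dataset and return them alongside each raw sample. Shuffling then no longer matters, because every row of a batch carries its own day's mean and std. `DailyStatsDataset` is a hypothetical name and the data is synthetic.

```python
from typing import List, Tuple

import torch
from torch.utils.data import DataLoader, Dataset


class DailyStatsDataset(Dataset):
    """Hypothetical: raw samples plus the mean/std of the day each came from."""

    def __init__(self, days: List[torch.Tensor]):
        self.x = torch.cat(days)
        # Per-day statistics, computed once.
        self.means = torch.stack([d.mean(dim=0) for d in days])
        self.stds = torch.stack([d.std(dim=0) for d in days])
        # Day index of every sample, so the stats survive shuffling.
        self.day = torch.cat(
            [torch.full((len(d),), i, dtype=torch.long) for i, d in enumerate(days)]
        )
        # Aggregates to store for inference.
        self.global_mean = self.means.mean(dim=0)
        self.global_std = self.stds.mean(dim=0)

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        d = int(self.day[i])
        return self.x[i], self.means[d], self.stds[d]


torch.manual_seed(0)
days = [torch.randn(100, 3) + 5.0, torch.randn(150, 3) * 2.0 - 1.0]
ds = DailyStatsDataset(days)
loader = DataLoader(ds, batch_size=32, shuffle=True)
x, mean, std = next(iter(loader))  # each row has its own day's mean/std
x_std = (x - mean) / std           # per-day standardization after shuffling
```

The default collate function stacks the three returned tensors, so each batch arrives as `(x, mean, std)` with matching shapes; the standardization itself can then happen inside the model rather than in the loader.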

Any ideas on best practices for this?