How to call a function on the Dataset for every batch in Dataloader?

I am using a model where I want to initialize some random state in my Dataset that decides how many context points to add to the current instance (from this paper:

I can’t figure out how to call a function once per batch. My only requirements are that:

  1. Its result can be seen by the Dataset object, so I can use it in the Dataset's __getitem__ method.
  2. It will be called once for every batch.

Is there a way to accomplish this?
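One possible approach, sketched here as an assumption, is to draw the per-batch state inside a custom batch sampler and hand it to the Dataset together with each index. The batch sampler is iterated in the main process, and each list of indices it yields is sent to a worker as one unit. Every sample of a batch therefore sees the same value, even with num_workers > 0. ContextBatchSampler, max_context and the (index, num_context) tuple format are hypothetical choices made for this sketch.

```python
import random

import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class ContextBatchSampler(Sampler):
    """Yields batches of (index, num_context) pairs; num_context is drawn once per batch."""

    def __init__(self, num_samples, batch_size, max_context, seed=0):
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.max_context = max_context
        self.rng = random.Random(seed)

    def __iter__(self):
        indices = list(range(self.num_samples))
        self.rng.shuffle(indices)
        for start in range(0, self.num_samples, self.batch_size):
            # The "function called once per batch": draw the shared random state here.
            num_context = self.rng.randint(1, self.max_context)
            yield [(idx, num_context) for idx in indices[start:start + self.batch_size]]

    def __len__(self):
        # Number of batches, counting a smaller final batch.
        return (self.num_samples + self.batch_size - 1) // self.batch_size


class ContextDataset(Dataset):
    """Toy dataset whose __getitem__ receives the per-batch state with the index."""

    def __init__(self, length=10):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        # The batch sampler passes a tuple instead of a plain integer index.
        idx, num_context = item
        return torch.tensor([idx, num_context])


dataset = ContextDataset(10)
loader = DataLoader(
    dataset,
    batch_sampler=ContextBatchSampler(len(dataset), batch_size=4, max_context=5),
    num_workers=0,
)
for batch in loader:
    print(batch[:, 1].tolist())
```

The same loader can be created with num_workers > 0 without other changes, since the random state travels with the indices rather than living on the Dataset object.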

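You could create an internal attribute in your Dataset's __init__ and update it inside the training loop via loader.dataset.my_flag = my_value.
__getitem__ can then check self.my_flag. Since __getitem__ is called for every sample of a batch, setting the flag right before the next batch is fetched gives all samples in that batch the same value.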
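A minimal sketch of that idea, assuming num_workers=0. ContextDataset, the num_context attribute and the run_epoch helper are hypothetical names, and the random range of 1 to 9 is arbitrary. The loop calls next() explicitly so the flag is set right before each batch is fetched.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ContextDataset(Dataset):
    """Toy dataset whose __getitem__ reads a per-batch flag (hypothetical example)."""

    def __init__(self, length=8):
        self.length = length
        # Hypothetical per-batch state, overwritten from the training loop.
        self.num_context = 0

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Each sample records the flag value that was active when it was loaded.
        return torch.tensor([idx, self.num_context])


def run_epoch(loader, generator):
    """Set a fresh num_context before each batch is fetched; return (value, batch) pairs."""
    results = []
    batch_iter = iter(loader)
    for _ in range(len(loader)):
        # The "function called once per batch": draw new random state.
        value = int(torch.randint(1, 10, (1,), generator=generator))
        loader.dataset.num_context = value
        # With num_workers=0 the batch is built here, after the flag was set.
        batch = next(batch_iter)
        results.append((value, batch))
    return results


loader = DataLoader(ContextDataset(), batch_size=4, num_workers=0)
results = run_epoch(loader, torch.Generator().manual_seed(0))
for value, batch in results:
    print(value, batch[:, 1].tolist())
```

With a plain for batch in loader: loop, you get the same effect by setting the attribute once before the loop and again at the end of the loop body. With num_workers=0, the next batch is only built when the loop advances.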

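However, if I’m not mistaken, this will only work with a single process (num_workers=0). With num_workers > 0, each worker holds its own copy of the Dataset, so an attribute changed in the main process won’t be seen by the workers.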
For multiple workers in your DataLoader, a shared array could probably be used as the flag, as shown in this small example.
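Here is a sketch of the shared-memory variant. It uses the standard-library multiprocessing.Value as a stand-in for the shared array, which is an assumption; the linked example may use a different container. The Dataset reads self.num_context.value in __getitem__, and the training loop writes into the same shared slot, so worker processes see the update.

One caveat: with num_workers > 0, each worker prefetches batches ahead of time (prefetch_factor, 2 per worker by default). A new value therefore only affects batches that have not been requested yet, not necessarily the very next one. SharedFlagDataset, set_num_context and run_epoch are hypothetical names.

```python
import multiprocessing as mp

import torch
from torch.utils.data import DataLoader, Dataset


class SharedFlagDataset(Dataset):
    """Dataset that reads its per-batch state from shared memory (hypothetical example)."""

    def __init__(self, length=8):
        self.length = length
        # "i" = C int stored in shared memory; workers started by the DataLoader
        # read the same slot instead of a private copy.
        self.num_context = mp.Value("i", 0)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return torch.tensor([idx, self.num_context.value])


def set_num_context(loader, value):
    # Write into the shared slot rather than rebinding a plain attribute.
    with loader.dataset.num_context.get_lock():
        loader.dataset.num_context.value = value


def run_epoch(loader, values):
    """Set one value per batch and collect (value, batch) pairs."""
    results = []
    batch_iter = iter(loader)
    for value in values:
        set_num_context(loader, value)
        results.append((value, next(batch_iter)))
    return results


# num_workers=0 keeps this demo deterministic; with num_workers > 0, build the
# loader inside an `if __name__ == "__main__":` block and expect prefetch lag.
loader = DataLoader(SharedFlagDataset(), batch_size=4, num_workers=0)
results = run_epoch(loader, [3, 7])
for value, batch in results:
    print(value, batch[:, 1].tolist())
```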