I am using a model (from this paper: https://arxiv.org/pdf/1807.01613.pdf) in which some random state in my Dataset decides how many context points to add to the current instance.
I can't figure out how to call a function on each batch. My only requirement is that its result can be seen by the Dataset object, so that I can use it in its __getitem__ method.
You could probably create an internal attribute in your Dataset's __init__ and change it inside the training loop via loader.dataset.my_flag = my_value.
__getitem__ can then read .my_flag to decide what to return.
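However, if I'm not mistaken, this only works reliably with a single process (num_workers=0): each worker process holds its own copy of the Dataset, so an attribute changed in the main process while the workers are running will not reach them.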
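A minimal sketch of the single-process approach, assuming the flag is an integer number of context points and that each instance is a row of a tensor; `ContextDataset` and `num_context` are hypothetical names. With num_workers=0 the DataLoader fetches each batch lazily when the loop asks for it, so a value set at the end of one iteration is what the next batch is built with.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ContextDataset(Dataset):
    """Hypothetical dataset: returns the first `num_context` points of each instance."""

    def __init__(self, data):
        self.data = data
        self.num_context = 1  # hypothetical flag, overwritten from the training loop

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Read the flag at fetch time; it only changes between batches,
        # so all items of one batch have the same length and can be stacked
        return self.data[idx][: self.num_context]


data = torch.arange(40, dtype=torch.float32).reshape(4, 10)  # 4 instances, 10 points each
dataset = ContextDataset(data)
# num_workers=0: batches are loaded in the main process, one per loop iteration
loader = DataLoader(dataset, batch_size=2, num_workers=0)

g = torch.Generator().manual_seed(0)
records = []  # (flag value used, width of the batch that came out)
for batch in loader:
    records.append((loader.dataset.num_context, batch.shape[1]))
    # Draw a new random number of context points (1..10) for the next batch
    loader.dataset.num_context = int(torch.randint(1, 11, (1,), generator=g))
```

Each entry of `records` pairs the flag value with the width of the batch it produced, and the two always match here.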
For multiple workers in your DataLoader, a shared array could probably serve as the flag, as shown in this small example.
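A minimal sketch of the shared-memory idea, assuming a `multiprocessing.Value` (a single shared C int) in place of an array; `SharedFlagDataset` and `run_epochs` are hypothetical names. It sets persistent_workers=True so the same worker processes survive across epochs (a plain attribute would go stale in them), and it only changes the flag between epochs. That restriction is deliberate: workers prefetch batches ahead of the loop, so a mid-epoch change would only affect batches not yet requested, and a change landing while a worker is assembling a batch would give items of different lengths that the default collate function cannot stack.

```python
import multiprocessing as mp

import torch
from torch.utils.data import DataLoader, Dataset


class SharedFlagDataset(Dataset):
    """Hypothetical dataset whose context size lives in shared memory."""

    def __init__(self, data):
        self.data = data
        # "i" = C int in shared memory; every DataLoader worker sees the same value
        self.num_context = mp.Value("i", 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Reading .value takes the Value's own lock and returns the current int
        n = self.num_context.value
        return self.data[idx][:n]


def run_epochs(schedule):
    data = torch.arange(80, dtype=torch.float32).reshape(8, 10)
    dataset = SharedFlagDataset(data)
    # persistent_workers=True keeps the same worker processes alive across epochs
    loader = DataLoader(dataset, batch_size=2, num_workers=2, persistent_workers=True)
    widths_per_epoch = []
    for n in schedule:
        # Update the shared flag before the next epoch starts prefetching
        loader.dataset.num_context.value = n
        widths_per_epoch.append({batch.shape[1] for batch in loader})
    return widths_per_epoch


if __name__ == "__main__":
    print(run_epochs([3, 5]))  # one set of batch widths per epoch
```

The `__main__` guard matters on platforms where workers are started with the spawn method (e.g. Windows and macOS), since each worker re-imports the script.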