Keras on_epoch_end functionality

I want to call a function defined in my dataset class at the end of every training epoch. I’m not sure if this is the right approach, so I’d like some feedback. The current structure looks something like this:

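import torch.utils.data

class my_dataset(torch.utils.data.Dataset):
    def __init__(self):
        ...
    def __len__(self):
        ...
    def __getitem__(self, idx):
        ...
    def my_func(self):
        ...  # should run once at the end of every epoch

data = my_dataset()
data_loader = torch.utils.data.DataLoader(data)

def training():
    for i, sample in enumerate(data_loader):
        ...  # do something with each batch
    data.my_func()  # runs after the epoch's iterator is exhausted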

I know that I can call data.my_func(), but will the change be reflected in data_loader? I guess I’m trying to replicate the on_epoch_end method that Keras provides on its Sequence class. Is this the right way to do it?
Thanks!


What would you like to achieve?
You could also call data_loader.dataset.my_func(), since data_loader.dataset is the same object as data. Both ways should work.
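A small sketch of why both calls work, using a toy dataset (`CounterDataset` and its `offset` attribute are hypothetical): `DataLoader` keeps a reference to the dataset you pass in rather than a copy, so state changed between epochs is seen by the next epoch's iterator. The sketch uses the default `num_workers=0`. With `num_workers > 0`, each worker process gets its own copy of the dataset when an epoch's iterator starts; changes made in the main process between epochs are still picked up because workers are recreated every epoch by default, but not with `persistent_workers=True`, and changes made inside a worker never flow back to the main process.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CounterDataset(Dataset):
    """Hypothetical dataset whose samples depend on mutable state."""

    def __init__(self, n):
        self.n = n
        self.offset = 0

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return idx + self.offset

    def my_func(self):
        # Stand-in for an end-of-epoch update
        self.offset += 100


data = CounterDataset(4)
data_loader = DataLoader(data, batch_size=2)

# The loader holds the very same object, not a copy
assert data_loader.dataset is data

seen = []
for epoch in range(2):
    # Default collation turns each batch of ints into a 1-D tensor
    seen.append(torch.cat(list(data_loader)).tolist())
    data_loader.dataset.my_func()  # same effect as data.my_func()

print(seen)  # [[0, 1, 2, 3], [100, 101, 102, 103]]
```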

Ignite provides event handling such as @trainer.on(Events.EPOCH_COMPLETED). If you need to hook into many of these events, that might make the code cleaner. :wink:
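If you only need one or two hooks, a hand-rolled loop can give a similar structure without adding Ignite. This is not Ignite's API: `fit`, `NamesDataset`, and the choice to look for an `on_epoch_end` method on the dataset (the name is borrowed from Keras's `Sequence`) are assumptions made for illustration.

```python
from torch.utils.data import DataLoader, Dataset


class NamesDataset(Dataset):
    """Hypothetical dataset exposing a Keras-style on_epoch_end hook."""

    def __init__(self, names):
        self.names = list(names)
        self.epochs_done = 0

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        return self.names[idx]

    def on_epoch_end(self):
        # Put per-epoch work (e.g. reshuffling) here
        self.epochs_done += 1


def fit(data_loader, train_step, epochs, callbacks=()):
    """Minimal training loop that fires epoch-end hooks, Keras style."""
    for epoch in range(epochs):
        for batch in data_loader:
            train_step(batch)
        # Call the dataset's own hook if it defines one
        hook = getattr(data_loader.dataset, "on_epoch_end", None)
        if callable(hook):
            hook()
        # Extra per-epoch callbacks, similar in spirit to EPOCH_COMPLETED handlers
        for callback in callbacks:
            callback(epoch)


data = NamesDataset(["a", "b", "c"])
loader = DataLoader(data, batch_size=2, shuffle=True)
batches, finished = [], []
fit(loader, batches.append, epochs=3, callbacks=[finished.append])
```

The loop stays in plain PyTorch, so it is easy to debug; once the number of hooks grows, a library with an event system is likely the cleaner option.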

Thank you for the reply. At a high level, I just want to shuffle my data. I know that shuffle=True in the data loader randomizes the order in which samples are drawn, but because of how I implement __getitem__, the images loaded together are likely correlated. Randomly shuffling my image set after every epoch would break this correlation.

I’ll look into Ignite. I did come across Torchsample, but it wasn’t obvious how to do this there, whereas Ignite seems to support it directly. For now, though, I think I’ll go with the native PyTorch solution of calling the function after the train_loader iterator completes.

Thanks again for your help!

May I ask how you are shuffling the data in the Dataset?
Since __getitem__ receives an index, I assume you are shuffling by copying your data?
If so, implementing your own Sampler might be faster.

I have two variables in my dataset: self.list_image_names, a list of image names per object category, and self.num_images, the number of images in each category. My __getitem__ does something like this:

image_names = [self.image_names[i][idx % self.num_images[i]] for i in range(len(self.list_image_names))]

The order in which my model sees the images changes because of the default sampler, but the images returned together for a given idx are the same every epoch. To break this correlation, I plan to shuffle the data at the end of every epoch using:

self.image_names[i] = np.random.permutation(self.image_names[i])
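A self-contained sketch of the scheme described above. It assumes `image_names` is a list of per-category name lists (the post refers to both `self.list_image_names` and `self.image_names`; this sketch uses a single `image_names` attribute), and the epoch length of `max(self.num_images)` is an assumption. Category and file names are made up.

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    """Hypothetical dataset holding one list of image names per category."""

    def __init__(self, image_names):
        self.image_names = [list(names) for names in image_names]
        self.num_images = [len(names) for names in self.image_names]

    def __len__(self):
        # Assumption: one epoch walks through the largest category once
        return max(self.num_images)

    def __getitem__(self, idx):
        # One name per category; smaller categories wrap around via the modulo
        return [self.image_names[i][idx % self.num_images[i]]
                for i in range(len(self.image_names))]

    def my_func(self):
        # Shuffle each category independently so that a given idx groups
        # different images next epoch; .tolist() turns np.str_ back into str
        self.image_names = [np.random.permutation(names).tolist()
                            for names in self.image_names]


data = MyDataset([["cat0", "cat1", "cat2"], ["dog0", "dog1"]])
data_loader = DataLoader(data, batch_size=2, shuffle=True)

for epoch in range(2):
    for batch in data_loader:
        pass  # training step goes here
    data.my_func()  # end-of-epoch hook, like Keras's on_epoch_end
```

Only the short name lists are permuted, so the cost per epoch is negligible. Because `my_func` runs in the main process between epochs, the next epoch's iterator sees the new order (subject to the `persistent_workers` caveat when using worker processes).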

Implementing my own Sampler might be faster or the correct thing to do, but I think this shuffling will do what I want.
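For comparison, the Sampler route could look roughly like this. It assumes a dataset whose `__getitem__` accepts a tuple with one position per category; with a map-style dataset, `DataLoader` passes whatever the sampler yields to `dataset[...]`, so the sampler, not the dataset, decides which images are grouped together. `MultiCategoryDataset` and `PerCategoryShuffleSampler` are hypothetical names.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class MultiCategoryDataset(Dataset):
    """Hypothetical dataset indexed by a tuple with one index per category."""

    def __init__(self, image_names):
        self.image_names = [list(names) for names in image_names]

    def __len__(self):
        return max(len(names) for names in self.image_names)

    def __getitem__(self, idxs):
        # idxs holds one position per category, chosen by the sampler
        return [names[i] for names, i in zip(self.image_names, idxs)]


class PerCategoryShuffleSampler(Sampler):
    """Draws a fresh permutation per category every time it is iterated."""

    def __init__(self, num_images, length, generator=None):
        self.num_images = list(num_images)
        self.length = length
        self.generator = generator

    def __len__(self):
        return self.length

    def __iter__(self):
        # The DataLoader iterates the sampler once per epoch, so each epoch
        # gets new pairings without any end-of-epoch call
        perms = [torch.randperm(n, generator=self.generator).tolist()
                 for n in self.num_images]
        for k in range(self.length):
            # Smaller categories wrap around, like idx % num_images[i]
            yield tuple(p[k % len(p)] for p in perms)


names = [["cat0", "cat1", "cat2"], ["dog0", "dog1"]]
data = MultiCategoryDataset(names)
sampler = PerCategoryShuffleSampler([len(n) for n in names], length=len(data))
# collate_fn=list keeps each batch as a plain list of samples
loader = DataLoader(data, sampler=sampler, batch_size=2, collate_fn=list)
```

The name lists are never modified, and since sampling happens in the main process, new pairings reach the workers even with `persistent_workers=True`. The dataset-side shuffle above is simpler to read, so either approach is reasonable here.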

Yeah, I would do the same. Since you are only shuffling names/paths, it’s cheap.
I was worried you were shuffling (large) data, but this approach seems just fine. :wink: