Update Dataset class attribute while data is being loaded

I want to update my dataset class attributes after each epoch (actually during each epoch), which would affect the __getitem__ function in the next epoch.
I realized that when we have multiple workers, separate dataset instances are created (or something like that), so I won't have consistent attributes across all instances.
Is there an easy way to do this???
(Please ask if question is not clear :sweat_smile:)

What kind of dataset attributes do you want to update?
It could be problematic to change the underlying Dataset when using multiprocessing.
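For example, here is a minimal sketch of the issue (the seen counter is a made-up attribute just for illustration): with num_workers > 0 each worker process gets its own copy of the dataset, so updates made inside __getitem__ never reach the instance in the main process.

import torch
from torch.utils.data import Dataset, DataLoader

class CounterDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(8)
        self.seen = 0  # hypothetical attribute we try to update while loading

    def __getitem__(self, index):
        self.seen += 1  # only mutates this worker's copy of the dataset
        return self.data[index]

    def __len__(self):
        return len(self.data)

dataset = CounterDataset()
loader = DataLoader(dataset, batch_size=2, num_workers=2)

for _ in loader:
    pass

# each worker received its own copy of `dataset`, so the updates
# never made it back to the main process
print(dataset.seen)  # prints 0, not 8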

Hi, I built a custom dataset class for video data (I know about nvvl but can't use it right now).
Different videos may have different numbers of frames. During validation I slide across the videos in the following manner:

First 120 frames from all videos, then the next 120 frames from all, and so on. But once this sliding window covers the smallest video entirely, since I cannot return a 0-dimensional tensor, I have to send 8 frames from that video again and again until the sliding window finishes for the largest video. This means that I validate on the same set of frames (from the smaller videos) multiple times.
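Roughly, my current windowing logic looks like this (a simplified sketch; the clip length of 120 and the 8-frame fallback are the settings I am using):

def get_window(video, step, clip_len=120, fallback=8):
    # video: tensor of shape (num_frames, C, H, W)
    start = step * clip_len
    if start + clip_len <= len(video):
        return video[start:start + clip_len]
    # this video is already exhausted, but I cannot return an empty
    # tensor, so I keep sending its last `fallback` frames
    return video[-fallback:]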

So I thought of storing attributes tracking which videos have finished and ignoring such cases. But it doesn't work, thanks to multiprocessing.

Any suggestions? (I am open to a completely different approach as well!)

You could try to return the frames from each video for as long as possible and stack the frames from the different videos into the batch dimension.
This would change your effective batch size, but it could work.
I’ve created a small example:

import torch
from torch.utils.data import Dataset, DataLoader


class VideoDataset(Dataset):
    def __init__(self):
        # two dummy "videos" with different numbers of frames
        self.video1 = torch.randn(10, 3, 24, 24)
        self.video2 = torch.randn(20, 3, 24, 24)

    def __getitem__(self, index):
        # return the frame at `index` from each video that still has frames left
        if index < len(self.video1):
            x1 = self.video1[index]
        else:
            x1 = None

        if index < len(self.video2):
            x2 = self.video2[index]
        else:
            x2 = None

        # stack whichever frames are available along a new leading dimension
        x = torch.cat([x.unsqueeze(0) for x in [x1, x2] if x is not None])
        return x

    def __len__(self):
        return len(self.video2)  # longest video


dataset = VideoDataset()

loader = DataLoader(
    dataset,
    batch_size=2,
    num_workers=2
)

for data in loader:
    print('Before ', data.shape)
    # flatten the video and batch dimensions into a single batch dimension
    data = data.view(-1, 3, 24, 24)
    print('After ', data.shape)

Note that you would have to implement your frame reading operation instead of just slicing the tensor.
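For instance, something along these lines using torchvision.io.read_video might work as a starting point (assuming your videos are in a container format torchvision can decode; read_frames is a made-up helper name):

import torchvision.io as io

def read_frames(path, start_sec, end_sec):
    # decode only the requested time window; returns a (T, H, W, C) uint8 tensor
    frames, _, _ = io.read_video(path, start_pts=start_sec, end_pts=end_sec, pts_unit='sec')
    # convert to (T, C, H, W) float in [0, 1]
    return frames.permute(0, 3, 1, 2).float() / 255.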

Nice! But a batch size of 1 with 120 frames in the temporal dimension is my memory limit (10 GB), so I will have to iterate sequentially.
And my videos are quite big, up to 4500 frames!
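Something like the following is what I have in mind (a rough sketch; load_video is a hypothetical helper that decodes one video at a time, so only one video is ever held in memory):

clip_len = 120

for video_path in video_paths:
    video = load_video(video_path)  # hypothetical: returns a (T, 3, H, W) tensor
    for start in range(0, len(video), clip_len):
        clip = video[start:start + clip_len]  # the last clip may be shorter
        # run the validation forward pass on clip.unsqueeze(0) here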