How to load N sequential frames in batches?

Hello,

I have the following setup for my data:

data/
       - 0000/
                - 0000.png
                - 0001.png
                - ...
       - 0001/
                - 0000.png
                - ...
       - ...

I’m trying to load N sequential frames (e.g. N = 5) per sample, so that one batch would return the following:

batch0:
- 0000/0000.png, 0000/0001.png, ..., 0000/N-1.png
- 0000/0001.png, 0000/0002.png, ..., 0000/N.png
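The layout above is a sliding window with stride 1 over each directory. Here is a minimal pure-Python sketch of the index arithmetic, assuming a directory with F frames and a window of length N: a window starting at i needs frames i .. i + N - 1, so the valid starts are 0 .. F - N, giving F - N + 1 windows. `sliding_windows` is a hypothetical helper name.

```python
def sliding_windows(frame_names, seq_len):
    """Hypothetical helper: every run of seq_len consecutive frames (stride 1)."""
    # A window starting at i needs frames i .. i + seq_len - 1 to exist,
    # so the last valid start is len(frame_names) - seq_len
    return [frame_names[i:i + seq_len] for i in range(len(frame_names) - seq_len + 1)]


frames = [f"0000/{i:04d}.png" for i in range(7)]
for window in sliding_windows(frames, seq_len=5):
    print(window[0], "...", window[-1])
# 0000/0000.png ... 0000/0004.png
# 0000/0001.png ... 0000/0005.png
# 0000/0002.png ... 0000/0006.png
```

If a directory has fewer than `seq_len` frames, the range is empty and no windows are produced.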

I’m quite new to PyTorch and feel a bit lost among the different alternatives. I’ve looked at the following threads, but haven’t found anything that quite fits my case:

  1. Can we sample multiple images all together?
  2. How upload sequence of image on video-classification

In the end I want to stack the images on top of each other in a single tensor of dimensions W x H x D x N. Can I combine the dataset’s __getitem__ and the sampler to make sure that the first sample contains the first N images and the last sample contains the last N images? My gut feeling is that __getitem__ should return the already-stacked tensor, so that batches of “sequences” are easy to create, but how can I restrict the indices so that no sequence runs past the start or end of a directory?
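One way to get this without a custom sampler, sketched under assumptions: give each directory its own dataset whose `__len__` counts only the windows that fit, and have `__getitem__` return the stacked tensor. `SequenceDataset` and its `load_frame` argument are hypothetical names; the sketch assumes every frame loads to a tensor of the same shape and that file names are zero-padded, so sorting them gives frame order.

```python
import torch
from torch.utils.data import Dataset


class SequenceDataset(Dataset):
    """Hypothetical dataset: sample i is frames i .. i + seq_len - 1 of one directory, stacked."""

    def __init__(self, frame_paths, seq_len, load_frame):
        if len(frame_paths) < seq_len:
            raise ValueError("directory has fewer frames than seq_len")
        # Assumes zero-padded names, so lexicographic order equals frame order
        self.frame_paths = sorted(frame_paths)
        self.seq_len = seq_len
        # Callable: path -> tensor; every frame must have the same shape
        self.load_frame = load_frame

    def __len__(self):
        # Only start indices whose full window fits inside the directory are valid,
        # so the last seq_len - 1 frames are never used as a window start
        return len(self.frame_paths) - self.seq_len + 1

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        window = self.frame_paths[idx:idx + self.seq_len]
        frames = [self.load_frame(p) for p in window]
        # Stack along a new last dimension: frame shape (C, H, W) -> (C, H, W, seq_len)
        return torch.stack(frames, dim=-1)
```

For real images you could pass, for example, `torchvision.io.read_image` (which returns a `C x H x W` uint8 tensor) as `load_frame` and `glob.glob("data/0000/*.png")` as `frame_paths`. Wrapped in a `DataLoader` with `batch_size=B`, each batch would then have shape `(B, C, H, W, N)`, which can be permuted to whatever layout the model expects.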

@ptrblck’s answers in the first and second threads are definitely interesting, and I’ve looked at using SequentialSampler/BatchSampler, but that doesn’t solve the problem of “starting” at the Nth image.
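If `__getitem__(i)` instead returns the window *ending* at frame `i`, the first usable index is N - 1. A possible sketch, not a confirmed answer from the thread: wrap the dataset in `torch.utils.data.Subset` with `range(N - 1, len(dataset))`, after which the standard `SequentialSampler` and `BatchSampler` work unchanged. `EndingAtWindows` is a hypothetical toy dataset whose "frames" are integer ids rather than images.

```python
import torch
from torch.utils.data import BatchSampler, Dataset, SequentialSampler, Subset


class EndingAtWindows(Dataset):
    """Hypothetical toy dataset: item i is the seq_len frame ids ending at frame i."""

    def __init__(self, num_frames, seq_len):
        self.num_frames = num_frames
        self.seq_len = seq_len

    def __len__(self):
        return self.num_frames

    def __getitem__(self, i):
        # Only meaningful for i >= seq_len - 1; earlier indices would reach before frame 0
        return torch.arange(i - self.seq_len + 1, i + 1)


full = EndingAtWindows(num_frames=7, seq_len=5)
# Keep only the indices whose window starts at or after frame 0
valid = Subset(full, range(full.seq_len - 1, len(full)))
batches = list(BatchSampler(SequentialSampler(valid), batch_size=2, drop_last=False))
print(batches)  # [[0, 1], [2]]  (positions inside `valid`, i.e. frames 4, 5, 6)
print([valid[i].tolist() for i in batches[0]])  # [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]]
```

The `Subset` does the index restriction once, so the samplers never see an index that would reach before the first frame.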

Thanks,

Erik


Also, since the different directories don’t overlap, I found that ConcatDataset might do the trick for training on all of the directories at once (from this Stack Overflow thread).
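A small sketch of why ConcatDataset fits here, assuming one windowed dataset per directory: `ConcatDataset` maps a global index to a (sub-dataset, local index) pair using its `cumulative_sizes`, so no window ever mixes frames from two directories. `FrameIdWindows` is a hypothetical stand-in that returns integer frame ids instead of images.

```python
import torch
from torch.utils.data import ConcatDataset, Dataset


class FrameIdWindows(Dataset):
    """Hypothetical per-directory dataset: windows of seq_len integer frame ids."""

    def __init__(self, num_frames, seq_len):
        self.num_frames = num_frames
        self.seq_len = seq_len

    def __len__(self):
        # Number of windows that fit entirely inside this directory
        return max(0, self.num_frames - self.seq_len + 1)

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        return torch.arange(idx, idx + self.seq_len)


# One dataset per directory, e.g. data/0000 with 7 frames and data/0001 with 6
combined = ConcatDataset([FrameIdWindows(7, 5), FrameIdWindows(6, 5)])
print(len(combined))               # 5  (3 windows + 2 windows)
print(combined.cumulative_sizes)   # [3, 5]
print(combined[2].tolist())        # last window of the first directory: [2, 3, 4, 5, 6]
print(combined[3].tolist())        # first window of the second directory: [0, 1, 2, 3, 4]
```

Because each sub-dataset only counts windows that fit inside its own directory, concatenating them needs no extra index filtering.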

I think I found a solution!

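from torch.utils.data import Sampler


class MySampler(Sampler):
    """
    Sampler that skips indices too close to the start or end of the dataset,
    so every yielded index has enough neighbouring frames for a full sequence.
    """
    def __init__(self, data_source, input_config):
        self.config = input_config
        self.data_source = data_source
        super().__init__(data_source)

    def _valid_indices(self):
        n = self.config.num_conseq_frames
        total = len(self.data_source)
        # i >= n - 1: at least n - 1 frames precede i
        # total - i > n: at least n frames follow i
        return [i for i in range(total) if i >= n - 1 and total - i > n]

    def __iter__(self):
        return iter(self._valid_indices())

    def __len__(self):
        # Must return an int, and should match what __iter__ actually yields
        return len(self._valid_indices())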