A dataloader for multiple similar inputs

I’m not sure, how you are loading the data, but have a look at this code example, which should yield the right shapes:

class MyDataset(Dataset):
    def __init__(self):
        pass
    
    def __getitem__(self, index):
        image1 = torch.randn(3, 224, 224)
        image2 = torch.randn(3, 224, 224)
        image3 = torch.randn(3, 224, 224)
        target = torch.randint(0, 10, (1, ))
        return image1, image2, image3, target

    def __len__(self):
        return 128


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        pass
    
    def forward(self, image1, image2, image3):
        print('image1 shape', image1.shape)
        print('image2 shape', image2.shape)
        print('image3 shape', image3.shape)
        x = image1
        y = image2
        z = image3
        return x, y, z


model = MyModel()
dataset = MyDataset()
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True
)

for *data, target in loader:
    print(type(data))
    outputs = model(*data)

> <class 'list'>
image1 shape torch.Size([64, 3, 224, 224])
image2 shape torch.Size([64, 3, 224, 224])
image3 shape torch.Size([64, 3, 224, 224])
<class 'list'>
image1 shape torch.Size([64, 3, 224, 224])
image2 shape torch.Size([64, 3, 224, 224])
image3 shape torch.Size([64, 3, 224, 224])

Did not understand this part

I created random data. You should of course load your data according to your use case.
The code snippet just gives the general workflow using some randomly initialized tensors.

Hi @ptrblck . I fixed the shape mismatch problem, it was due to a bug in the code.

When I’m creating the dataest, I can check that the dataset outputs three images. But when I wrap it in a dataloader with a batch size more than 1, I get an index error only when shuffle=True.

class MyDataset(Dataset):
    
    def __init__(self, video_path, transform, nb_frames=3):
        self.nb_frames = nb_frames
        self.video_path = video_path
        classes, class_to_idx = self._find_classes(self.video_path)
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        
        
        folders = []
        for target in sorted(class_to_idx.keys()):
            d = os.path.join(video_path,target)
            if not os.path.isdir(d):
                continue
            for folder in sorted(os.listdir(d)):
                folders.append([d + "/" + folder,class_to_idx[target]])
        
        # The folders array is of the format [['dir/train/fake/001_003',0],['dir/train/real/001',1], ... etc]
        # Now we have to store the individual images in order to feed them into the network 
        
        self.samples = []
        for f in folders:
            fold = sorted(os.listdir(f[0]))           #f[0] is the name of the folder, f[1] is the target
            idx = f[1] 
            # Crop data to multiple of nb_frames
            fold1 = fold[:-(len(fold)%self.nb_frames)] if len(fold)%self.nb_frames!=0 else fold
#             print([fold1,idx])
#             break
            for file in range(len(fold1)):
                fold1[file] = f[0] + "/" +  fold1[file]
            
            self.samples.append([fold1,idx])
        
        #self.samples now contain data in the form of - [['../001/1.jpg','../001/2.jpg','001/3.jpg',...],1],[['../001_003/1.jpg',...],0]]
        
        # calculate lengths
        self.lens = [len(d[0])//self.nb_frames for d in self.samples]
        # calculate offsets
        self.offsets = np.concatenate(([0], np.cumsum(self.lens[:-1])))
    
    def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.
        Args:  dir (string): Root directory path.
        Returns: tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx
    
    def pil_loader(self, path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')
    
    def transform_self(self, image):
        # Resize
        resize = transforms.Resize(size=(224, 224))
        image = resize(image)
        
        # Transform to tensor
        image = TF.to_tensor(image)
        
        return image

    
    def __getitem__(self, index):
        # subtract offset
        print('index: {}'.format(index))
        # get corresponding video file
        found = False
        for i, offset in enumerate(self.offsets):
            if index < offset:
                #print('subtracting {} from index'.format(self.offsets[i-1]))
                index -= self.offsets[i-1]
                index *= self.nb_frames
                found = True
                break
        # handle last video separately
        if not found:
            index -= self.offsets[-1]
            index *= self.nb_frames
            i += 1

        # select correspondind data
        print('selecing video {}'.format(i-1))
        samples = self.samples[i-1][0]
        target = self.samples[i-1][1]   # for a particular index, the target remains the same for the three images
        #print(target)
        
        # get frames
        print('reading frames {}'.format([idx for idx in range(index, index+self.nb_frames)]))
        x = []
        for idx in range(index, index+self.nb_frames):
            path = samples[idx]
            tmp = self.pil_loader(path)

            if self.transform is not None :
                tmp = self.transform(tmp)
            
#             tmp = self.transform_self(tmp)
             
            x.append(tmp)
        
        #print(len(x))
        #x = torch.cat(x)
        #print(x.shape)
        return x[0],x[1],x[2],target
        
    def __len__(self):
        return np.sum(self.lens)*self.nb_frames

This is the dataset class.

image_datasets = {x: MyDataset(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val', 'test']}

Here, when I do

for data in image_datasets['train']:
    print(data[0].shape,data[1].shape,data[2].shape)
    

#this gives the perfectly okay result -
index: 0
selecing video 0
reading frames [0, 1, 2]
torch.Size([3, 224, 224]) torch.Size([3, 224, 224]) torch.Size([3, 224, 224])
#and so on

But when I wrap it into the dataloader -

loader = DataLoader(
    image_datasets['train'],
    batch_size=64,
    shuffle=True
)

for *data, target in loader:
    print(target)

#this gives the output - 
index: 30630
selecing video 589
reading frames [35556, 35557, 35558]
#and then the error - 
   104         x = []
    105         for idx in range(index, index+self.nb_frames):
--> 106             path = samples[idx]
    107             tmp = self.pil_loader(path)
    108 

IndexError: list index out of range

I can’t understand why the dataloader shows the index error only shuffle = True