Creating a dataset that returns two images efficiently

I am working on a project where I have to load an image together with its precomputed features. The features (single channel, same dimensions as the image) are stored on the same disk, with the same file names and directory structure as the source images.
The idea is to load both images together, apply the same geometric preprocessing (rotation/cropping/affine) to both, and normalize only the source images.
I have tried to build a dataset, but it is way too slow and often blocks the dataloader. The code is something like this:

import torch
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as TF


class CustomDataset(torchvision.datasets.ImageFolder):
    def __init__(self, src, feat_src):
        super().__init__(src)
        self.src = src
        self.feat_src = feat_src

        # Geometric augmentations applied to image and features together.
        # The concrete transforms here are placeholders for the
        # rotation/cropping/affine pipeline mentioned above.
        self.transformation = T.Compose([
            T.RandomRotation(degrees=15),
            T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        ])
        # Normalization applied to the source image only
        # (ImageNet statistics used as a placeholder).
        self.normalization = T.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    def __getitem__(self, idx):
        img_path, label = self.imgs[idx]
        feat_path = img_path.replace(self.src, self.feat_src)

        # open the source image (RGB) and the single-channel feature map
        img = self.loader(img_path)
        feat = self.loader(feat_path).convert("L")

        img = TF.to_tensor(img)    # (3, H, W)
        feat = TF.to_tensor(feat)  # (1, H, W)

        # concatenate along the channel dimension so both tensors
        # receive exactly the same geometric transformation
        combo = torch.cat([img, feat], dim=0)  # (4, H, W)
        combo = self.transformation(combo)

        img = self.normalization(combo[:-1])  # isolate and normalize the image
        feat = combo[-1:]                     # isolate features, keep channel dim

        return img, feat, label

Right now, loading stalls badly on anything other than a PCIe 4.0 NVMe SSD (I tried cloud and remote servers with beefy GPU and CPU configs and could barely get through 100 dataloader steps in a minute). I know this can be much faster, as a plain single-image ImageFolder pipeline runs at ~250 dataloader steps per minute.

I have tried a lot of things to no avail. I want this code to be as efficient as possible, since the training loop will run for ~2-4 days. Could anyone please suggest how I can make it more efficient? Thanks in advance.

I don’t know which setup you are currently using, but if data-loading bandwidth is the bottleneck, your options for speeding up the code are limited beyond upgrading your storage. You could try increasing the number of DataLoader workers so that they can hopefully prefetch the data fast enough.
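Something like this might be a starting point (the batch size, worker count, and paths are placeholders you would tune for your machine):

from torch.utils.data import DataLoader

dataset = CustomDataset(src="path/to/images", feat_src="path/to/features")
loader = DataLoader(
    dataset,
    batch_size=32,            # placeholder; tune for your GPU memory
    shuffle=True,
    num_workers=8,            # parallel worker processes decoding images
    prefetch_factor=4,        # batches each worker prepares in advance
    persistent_workers=True,  # keep workers alive between epochs
    pin_memory=True,          # faster host-to-GPU transfers
)

persistent_workers avoids respawning the worker processes at the start of every epoch, which adds up over a multi-day run.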
