Collate Function for Dataloader

Hello members,
I am stuck and need some help!

The batches from my dataset need to be of uniform size when I enumerate the DataLoader, so I have written a custom collate function. The problem is that when I pass this collate function to the DataLoader, every batch it yields is None.

It seems like a silly mistake, but I can't figure it out.

The collate function is meant to handle variable-size inputs by padding every image in a batch to the batch's maximum width and maximum height.

1. Collator Function

def DimensionCollator(object):
    def __call__(self,batch):
        print(self.batch)
        max_width =  max([item['img'].shape[1] for item in batch])
        max_height = max([item['img'].shape[2] for item in batch])
        imgs = torch.ones([len(batch), batch[0]['img'].shape[0],max_width,max_height], dtype=torch.float32)

        for idx, item in enumerate(batch):
            try:
                imgs[idx,:,0:item['img'].shape[1],0:item['img'].shape[2]] = item['img']
            except:
                print(imgs.shape)
        
        item = {'img':imgs,'label':label}
        
        return item

2. Dataset Class

# 1. Creating Dataset
class TextRecogDataset(Dataset):
    def __init__(self,subset) :
        super(TextRecogDataset,self).__init__()
        self.subset = subset
        if self.subset=='train':
            self.df = pd.read_csv('/home/suraj/ClickUp/Jan-Feb/data/ocr_data/written_name_train.csv')
        else:
            self.df = pd.read_csv('/home/suraj/ClickUp/Jan-Feb/data/ocr_data/written_name_validation.csv')
        transform = [transforms.Grayscale(1), transforms.ToTensor(),transforms.Normalize((0.5),(0.5))]
        self.transform = transforms.Compose(transform)
        

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        self.image_path = self.df['FILENAME'][index]
        self.label = self.df['IDENTITY'][index]
        # reading the image, applying the transform
        if self.subset=='train':
            img = Image.open('/home/suraj/ClickUp/Jan-Feb/data/ocr_data/train/'+self.image_path)
        else:
            img = Image.open('/home/suraj/ClickUp/Jan-Feb/data/ocr_data/val/'+self.image_path)
        if self.transform is not None:
            img = self.transform(img)
        
        item = {'img':img, 'label':self.label}
        
        return item
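To test a collate function without the CSV files and image folders, a synthetic stand-in can help. The sketch below is hypothetical (the class name `RandomTextLineDataset` is not from the thread): it mimics the `{'img', 'label'}` dicts returned by `TextRecogDataset`, with 1-channel C x H x W tensors whose sizes are taken from the output printed further down, and values in [-1, 1] as `Normalize((0.5), (0.5))` would produce.

```python
import torch
from torch.utils.data import Dataset


class RandomTextLineDataset(Dataset):
    """Hypothetical stand-in for TextRecogDataset: no files needed."""

    def __init__(self, sizes, labels):
        assert len(sizes) == len(labels)
        self.sizes = sizes    # list of (height, width) pairs
        self.labels = labels  # list of label strings

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, index):
        h, w = self.sizes[index]
        # Random values in [-1, 1], the range Normalize((0.5), (0.5)) maps [0, 1] to
        img = torch.rand(1, h, w) * 2 - 1
        return {'img': img, 'label': self.labels[index]}


ds = RandomTextLineDataset([(23, 388), (31, 284), (29, 388)],
                           ['TANGUY', 'MOBISA', 'FLAVIO'])
print(ds[1]['img'].shape, ds[1]['label'])  # torch.Size([1, 31, 284]) MOBISA
```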

3. Output when enumerating the DataLoader without passing the collate_fn and without increasing batch_size above 1

for i, train_batch in enumerate(train_loader):
    print(train_batch['img'].shape, train_batch['label'])

stdout
torch.Size([1, 1, 23, 388]) ['TANGUY']
torch.Size([1, 1, 31, 284]) ['MOBISA']
torch.Size([1, 1, 29, 388]) ['FLAVIO']
torch.Size([1, 1, 36, 388]) ['DALLONGEVILLE']
torch.Size([1, 1, 45, 284]) ['REYNARD']
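The output above has a leading batch dimension of 1, which is why no padding is needed there. As a sketch of why a custom collate function becomes necessary for batch_size > 1: the default collate stacks the `img` tensors of a batch with `torch.stack`, which raises a `RuntimeError` when their sizes differ. The two samples below reuse the shapes printed above.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Two samples shaped like the ones printed above (C x H x W after ToTensor)
batch = [
    {'img': torch.zeros(1, 23, 388), 'label': 'TANGUY'},
    {'img': torch.zeros(1, 31, 284), 'label': 'MOBISA'},
]

try:
    default_collate(batch)
    stacked = True
except RuntimeError as err:
    # torch.stack refuses to stack tensors of different sizes
    stacked = False
    print('default_collate failed:', err)
```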

Thanks in Advance !

Could you describe how you are using your DimensionCollator with the dataloader? It seems that you might have wanted to write

class DimensionCollator
...

rather than

def DimensionCollator
...
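A minimal sketch of why the `def` version yields None, assuming the function itself was passed as `collate_fn=DimensionCollator` (the thread does not show the DataLoader call, but this is one way to get exactly the reported symptom). The body of the `def` only defines a nested function named `__call__` and never returns anything, so every call evaluates to None. `DimensionCollatorFixed` is a hypothetical name used here only for the comparison.

```python
def DimensionCollator(object):  # parameter name `object` kept from the post
    # The body only *defines* a nested function called __call__;
    # nothing is returned, so each call evaluates to None.
    def __call__(self, batch):
        return batch


# The DataLoader would call collate_fn(list_of_samples), i.e. this:
print(DimensionCollator([{'img': None, 'label': 'TANGUY'}]))  # None


class DimensionCollatorFixed(object):
    def __call__(self, batch):
        return batch


collate = DimensionCollatorFixed()  # an *instance* of the class is callable
print(collate(['TANGUY']))  # ['TANGUY']
```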

Aye, thanks!
My bad: I defined the collate_fn as a function when it should have been a class. Silly mistake!

It worked! :smiley:
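For reference, here is a sketch of the collator written as a class, with two further issues from the original post addressed: `print(self.batch)` is dropped (the instance has no `batch` attribute), and the labels are collected from the batch, since `label` was never defined inside `__call__`. The `pad_value` parameter is an addition for illustration; its default of 1.0 matches `torch.ones` in the post and is what pure white maps to after `ToTensor` and `Normalize((0.5), (0.5))`. In the original, `max_width` actually holds the height (`shape[1]` of a C x H x W tensor), but it is used consistently, so the padding was still correct; the names below match the dimensions.

```python
import torch
from torch.utils.data import DataLoader


class DimensionCollator(object):
    """Pads every image in a batch to the batch's max height and width."""

    def __init__(self, pad_value=1.0):
        self.pad_value = pad_value

    def __call__(self, batch):
        # Each item['img'] is C x H x W: shape[1] is height, shape[2] is width
        max_height = max(item['img'].shape[1] for item in batch)
        max_width = max(item['img'].shape[2] for item in batch)
        channels = batch[0]['img'].shape[0]
        imgs = torch.full((len(batch), channels, max_height, max_width),
                          self.pad_value, dtype=torch.float32)
        for idx, item in enumerate(batch):
            _, h, w = item['img'].shape
            # Copy into the top-left corner; the bottom/right stays padding
            imgs[idx, :, :h, :w] = item['img']
        labels = [item['label'] for item in batch]
        return {'img': imgs, 'label': labels}


# A plain list of dicts works as a map-style dataset for a quick check
samples = [
    {'img': torch.zeros(1, 23, 388), 'label': 'TANGUY'},
    {'img': torch.zeros(1, 31, 284), 'label': 'MOBISA'},
    {'img': torch.zeros(1, 29, 388), 'label': 'FLAVIO'},
]
loader = DataLoader(samples, batch_size=2, shuffle=False,
                    collate_fn=DimensionCollator())
for batch in loader:
    print(batch['img'].shape, batch['label'])
# torch.Size([2, 1, 31, 388]) ['TANGUY', 'MOBISA']
# torch.Size([1, 1, 29, 388]) ['FLAVIO']
```

Note that `collate_fn` receives an instance, `DimensionCollator()`, not the class itself.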