Hello Members,
I am stuck and need some help!
The batches from my dataset need to be of uniform size while enumerating the DataLoader, so I wrote a custom collate function. The problem is that when I pass this collate function to the DataLoader, every batch it yields is None.
It seems like a silly mistake, but I can't figure it out.
The collate function's intention is to pad each variable-size image in a batch up to the batch's max width and max height.
1. Collator Function
class DimensionCollator(object):
    def __call__(self, batch):
        # Tensors from ToTensor are (C, H, W): shape[1] is height, shape[2] is width
        max_height = max(item['img'].shape[1] for item in batch)
        max_width = max(item['img'].shape[2] for item in batch)
        # Pad every image in the batch up to (max_height, max_width)
        imgs = torch.ones([len(batch), batch[0]['img'].shape[0], max_height, max_width],
                          dtype=torch.float32)
        for idx, item in enumerate(batch):
            imgs[idx, :, 0:item['img'].shape[1], 0:item['img'].shape[2]] = item['img']
        labels = [item['label'] for item in batch]
        return {'img': imgs, 'label': labels}
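One likely culprit, given the None output: the collator was declared with `def` rather than `class`, so calling `DimensionCollator(...)` runs an ordinary function that defines an unused inner `__call__` and implicitly returns None. A minimal, torch-free sketch of the difference (names here are illustrative, not from the original code):

```python
def broken_collator(batch):
    # Defining __call__ inside a *function* body does nothing useful:
    # the function just runs to the end and implicitly returns None.
    def __call__(self, b):
        return b

class WorkingCollator:
    # A class with __call__ produces instances that are callable,
    # which is what DataLoader expects from collate_fn.
    def __call__(self, batch):
        return list(batch)

print(broken_collator([1, 2]))    # None
print(WorkingCollator()([1, 2]))  # [1, 2]
```

So `collate_fn=DimensionCollator()` only works once `DimensionCollator` is a class whose instances implement `__call__`.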
2. Dataset Class
class TextRecogDataset(Dataset):
    def __init__(self, subset):
        super(TextRecogDataset, self).__init__()
        self.subset = subset
        if self.subset == 'train':
            self.df = pd.read_csv('/home/suraj/ClickUp/Jan-Feb/data/ocr_data/written_name_train.csv')
        else:
            self.df = pd.read_csv('/home/suraj/ClickUp/Jan-Feb/data/ocr_data/written_name_validation.csv')
        transform = [transforms.Grayscale(1), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        self.transform = transforms.Compose(transform)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        image_path = self.df['FILENAME'][index]
        label = self.df['IDENTITY'][index]
        # reading the image, applying transform
        if self.subset == 'train':
            img = Image.open('/home/suraj/ClickUp/Jan-Feb/data/ocr_data/train/' + image_path)
        else:
            img = Image.open('/home/suraj/ClickUp/Jan-Feb/data/ocr_data/val/' + image_path)
        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'label': label}
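For reference, the map-style dataset contract above only requires `__len__` and `__getitem__`. A torch-free sketch with an in-memory stand-in for the CSV rows (hypothetical class and data, for illustration only) follows the same protocol:

```python
class ToyTextRecogDataset:
    """Minimal map-style dataset: an in-memory stand-in for the CSV rows."""
    def __init__(self, rows):
        self.rows = rows  # list of (filename, identity) pairs

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        filename, label = self.rows[index]
        # A real implementation would load and transform the image here.
        return {'img': filename, 'label': label}

ds = ToyTextRecogDataset([('a.jpg', 'TANGUY'), ('b.jpg', 'MOBISA')])
print(len(ds), ds[1]['label'])  # 2 MOBISA
```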
3. Output while enumerating the DataLoader without passing the collate_fn (batch_size must stay at 1, since the images have variable sizes)
for i, train_batch in enumerate(train_loader):
    print(train_batch['img'].shape, train_batch['label'])
stdout
torch.Size([1, 1, 23, 388]) ['TANGUY']
torch.Size([1, 1, 31, 284]) ['MOBISA']
torch.Size([1, 1, 29, 388]) ['FLAVIO']
torch.Size([1, 1, 36, 388]) ['DALLONGEVILLE']
torch.Size([1, 1, 45, 284]) ['REYNARD']
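With a working collator, these five samples should pad to one uniform batch: the max height is 45 and the max width is 388, so the batch shape would be [5, 1, 45, 388]. A quick torch-free check of that arithmetic:

```python
# (H, W) of each sample, taken from the printout above
shapes = [(23, 388), (31, 284), (29, 388), (36, 388), (45, 284)]

max_h = max(h for h, _ in shapes)
max_w = max(w for _, w in shapes)
batch_shape = [len(shapes), 1, max_h, max_w]
print(batch_shape)  # [5, 1, 45, 388]
```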
Thanks in advance!