I am reading a video and splitting it into batches. I can't fit a whole video in one batch, so I split it using a collate function. The shape of the video is (9*32, 3, 224, 224).
.
The output shape from the data loader after using the collate function is
torch.Size([9, 32, 3, 224, 224])
where index 0 is the batch dimension. If I pass this batch of data, I get a memory error. Is there a way to send a batch of 3 instead?
Here is code
dl=DataLoader(DataReader(df,train_aug), batch_size =1,collate_fn=collate_fn)
here is custom dataset
class DataReader(Dataset):
    """Dataset yielding one whole video (as stacked frames) per row of ``df``.

    Each item is ``(frames, label)`` where ``frames`` is the tensor produced
    by stacking the augmented RGB frames of the video at ``df['path'][index]``
    and ``label`` is the float32 ``'pulse'`` value for that row.
    """

    def __init__(self, dataset, transform=None):
        self.df = dataset           # table with 'path' and 'pulse' columns
        self.transform = transform  # albumentations-style aug: aug(image=...)['image']

    def __getitem__(self, index):
        path = self.df['path'][index]
        y = self.df['pulse'][index]
        frames = readframes(path)   # frames come back in OpenCV BGR order
        processed = []
        for frame in frames:
            # OpenCV decodes to BGR; convert to RGB before augmenting.
            img = cv2.cvtColor(src=frame, code=cv2.COLOR_BGR2RGB)
            if self.transform is not None:
                # BUG FIX: the original called the global ``train_aug`` here,
                # silently ignoring the ``transform`` passed to __init__.
                img = self.transform(image=img)['image']
            processed.append(img)
        return torch.stack(processed), np.float32(y)

    def __len__(self):
        return len(self.df)
readframes
is a function that extracts the frames of a video.
def reshaped(i, bs=32):
    """Group frames into fixed-size clips, dropping the leading remainder.

    The first ``len(i) % bs`` frames are discarded so the remaining count is
    divisible by ``bs``; the rest are reshaped into clips of ``bs`` frames.
    Generalized from the original, which hard-coded the (3, 224, 224) frame
    shape — now any trailing frame shape works, backward-compatibly.

    Parameters
    ----------
    i : numpy array or torch tensor of shape (N, ...frame dims...)
    bs : int, frames per clip (default 32).

    Returns
    -------
    Array/tensor of shape (N // bs, bs, ...frame dims...).
    """
    n_clips = len(i) // bs
    # Keep only full clips: trim the leftover frames from the front.
    return i[len(i) % bs:].reshape(n_clips, bs, *i.shape[1:])
def collate_fn(batch, bs=32, max_clips=4):
    """Collate ``(frames, label)`` samples into clip-level micro-batches.

    Each element of ``batch`` is ``(frames, label)`` where ``frames`` holds
    N stacked frames. Frames are grouped into clips of ``bs`` frames via
    ``reshaped``, the scalar label is repeated once per clip, and at most
    ``max_clips`` clips are returned so the batch fits in memory.

    Parameters
    ----------
    batch : list of (frames, label) tuples from the Dataset.
    bs : int, frames per clip (default 32); now passed through to
        ``reshaped`` instead of being duplicated as a local constant.
    max_clips : int, cap on clips per returned batch (default 4, matching
        the original hard-coded ``[0:4]``). This is the memory knob —
        set it to 3 to get batches of 3 clips and avoid the OOM.

    Returns
    -------
    (data, labels): float tensor of shape (<=max_clips, bs, *frame dims)
    and a 1-D label tensor aligned with the clips.
    """
    data = np.concatenate([reshaped(sample[0], bs) for sample in batch])
    label = np.concatenate(
        [[sample[1]] * (len(sample[0]) // bs) for sample in batch]
    )
    # Cap the clip count so one long video cannot blow up GPU memory.
    data, label = data[:max_clips], label[:max_clips]
    # as_tensor avoids the extra copy torch.tensor() would make.
    return torch.as_tensor(data, dtype=torch.float), torch.as_tensor(label)