Control batch size in collate function

I am reading videos and splitting them into batches. A whole video does not fit in one batch, so I split it using a collate function. The shape of a video is (9*32, 3, 224, 224).

The output shape from the DataLoader after applying the collate function is
torch.Size([9, 32, 3, 224, 224]), where dimension 0 is the batch. If I pass this batch to the model, I get an out-of-memory error. Is there a way to send batches of 3 instead?

Here is the code:

dl = DataLoader(DataReader(df, train_aug), batch_size=1, collate_fn=collate_fn)

Here is the custom dataset:

class DataReader(Dataset):
    def __init__(self, dataset, transform=None):
        self.df = dataset
        self.transform = transform

    def __getitem__(self, index):
        x=self.df['path'][index]
        y=self.df['pulse'][index]
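        x=readframes(x)  # list of BGR frames of the whole video
        x1=[]
        for i in x:
            img=cv2.cvtColor(src=i, code=cv2.COLOR_BGR2RGB)
            # use the transform passed to __init__, not the global train_aug
            img=self.transform(image=img)['image']
            x1.append(img)
        return torch.stack(x1),np.float32(y)  # (T, 3, 224, 224), label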
        
    def __len__(self):
        return len(self.df)

readframes is a function that extracts the frames of a video.

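def reshaped(i,bs=32):
    # drop the first len(i) % bs frames so the rest splits evenly into chunks of bs frames
    i=i[len(i)%bs::].reshape(len(i)//bs,-1,3,224,224)
    return i
def collate_fn(batch):  # batch is a list of (frames, label) tuples returned by __getitem__
    bs=32
    # all chunks of all videos: (num_chunks, bs, 3, 224, 224)
    data=np.concatenate([reshaped(i[0]) for i in batch])
    # repeat each video's label once per chunk
    label=np.concatenate([[i[1]]*(len(i[0])//bs) for i in batch])
    data,label=data[0:4],label[0:4]  # keep only the first 4 chunks
    return torch.tensor(data,dtype=torch.float),torch.tensor(label)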
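For comparison, the same chunking can be written with torch operations only, which avoids the round trip through NumPy and turns the chunk length and the number of kept chunks into parameters. This is a minimal sketch; `make_collate_fn`, `chunk_len` and `max_chunks` are hypothetical names, and it assumes each sample is a `(frames, label)` pair with `frames` shaped `(T, C, H, W)`, as returned by `DataReader`.

```python
import torch


def make_collate_fn(chunk_len=32, max_chunks=None):
    """Return a collate_fn that cuts each video into chunks of chunk_len frames.

    Hypothetical helper: each sample is assumed to be a (frames, label) pair with
    frames shaped (T, C, H, W). The first T % chunk_len frames are dropped, as in
    reshaped().
    """

    def collate_fn(batch):
        chunks, labels = [], []
        for frames, label in batch:
            n = frames.shape[0] // chunk_len  # number of full chunks in this video
            if n == 0:
                continue  # video shorter than one chunk: nothing to keep
            kept = frames[frames.shape[0] % chunk_len:]  # drop the leading remainder
            chunks.append(kept.reshape(n, chunk_len, *frames.shape[1:]))
            labels.append(torch.full((n,), float(label)))  # one label per chunk
        if not chunks:
            raise ValueError("no video in this batch has at least chunk_len frames")
        data = torch.cat(chunks).float()  # (num_chunks, chunk_len, C, H, W)
        target = torch.cat(labels)        # (num_chunks,)
        if max_chunks is not None:
            # Still wasteful: the discarded chunks were already decoded and augmented.
            data, target = data[:max_chunks], target[:max_chunks]
        return data, target

    return collate_fn
```

It would be passed as `collate_fn=make_collate_fn(chunk_len=32, max_chunks=3)`. Note that truncating here only limits what reaches the model; the whole video has still been read and augmented in `__getitem__`.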

The collate_fn receives all loaded samples and creates the batch from them. Removing samples that have already been loaded and processed at such a late stage sounds a bit wasteful. Wouldn't it work to avoid loading the undesired frames in the first place?

My GPU memory (8GB) cannot fit even a single batch, so I want to split it in the collate function into sub-batches and pass those sub-batches to the model.

I guess your GPU is running out of memory during model training, not while the batch is being created.
If so, you could split the batch into chunks inside the DataLoader loop and perform one training step per chunk.
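Splitting the batch inside the training loop could look like the sketch below. It assumes the DataLoader yields `data` shaped `(num_chunks, 32, 3, 224, 224)` and `label` shaped `(num_chunks,)`, both still on the CPU. `train_in_sub_batches`, `TinyVideoRegressor` and `sub_batch_size` are hypothetical placeholders; the tiny model and small tensors only exist so the snippet runs on a CPU.

```python
import torch
import torch.nn as nn


def train_in_sub_batches(model, optimizer, criterion, data, label,
                         sub_batch_size=3, device="cpu"):
    """Run one optimizer step per sub-batch of a large batch kept on the CPU.

    Only one sub-batch at a time is moved to `device`, so peak GPU memory
    depends on sub_batch_size rather than on the size of the full batch.
    """
    model.train()
    losses = []
    for x, y in zip(torch.split(data, sub_batch_size),
                    torch.split(label, sub_batch_size)):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x).squeeze(1)  # (sub_batch,) predictions
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


class TinyVideoRegressor(nn.Module):
    """Placeholder model: average over frames and pixels, then a linear layer."""

    def __init__(self, channels=3):
        super().__init__()
        self.fc = nn.Linear(channels, 1)

    def forward(self, x):  # x: (B, T, C, H, W)
        return self.fc(x.mean(dim=(1, 3, 4)))  # (B, 1)


if __name__ == "__main__":
    model = TinyVideoRegressor()
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    data, label = torch.randn(9, 4, 3, 8, 8), torch.randn(9)
    print(train_in_sub_batches(model, opt, nn.MSELoss(), data, label, sub_batch_size=3))
```

With 9 chunks and `sub_batch_size=3` this performs three optimizer steps. If one update per video is preferred, the usual variant is gradient accumulation: call `optimizer.zero_grad()` once before the loop and `optimizer.step()` once after it, scaling each sub-batch loss accordingly.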
Alternatively, you could load only the desired number of samples in the Dataset and create a batch from this smaller number of samples. Removing samples in the collate_fn could work, but would be wasteful, as described previously.
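Loading only the desired frames in the Dataset could be sketched as follows: each index maps to one (video, chunk) pair, so `__getitem__` returns a single 32-frame clip and the DataLoader's own `batch_size` controls how many clips form a batch (e.g. `batch_size=3` would give `(3, 32, 3, 224, 224)`). `ChunkDataset`, `num_frames` and `load_frames` are hypothetical: `load_frames(path, start, length)` stands in for a version of `readframes` that decodes only the requested frame range, and `dummy_load_frames` replaces it here so the sketch runs without video files.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ChunkDataset(Dataset):
    """Each item is one chunk of chunk_len consecutive frames from one video."""

    def __init__(self, paths, labels, num_frames, load_frames, chunk_len=32):
        # paths, labels, num_frames: one entry per video (num_frames = frame count)
        self.paths = paths
        self.labels = labels
        self.load_frames = load_frames
        self.chunk_len = chunk_len
        # (video_idx, start_frame) for every full chunk; like reshaped(),
        # the first n % chunk_len frames of each video are skipped
        self.index = [
            (v, start)
            for v, n in enumerate(num_frames)
            for start in range(n % chunk_len, n - chunk_len + 1, chunk_len)
        ]

    def __len__(self):
        return len(self.index)

    def __getitem__(self, i):
        v, start = self.index[i]
        # only chunk_len frames are decoded, never the whole video
        frames = self.load_frames(self.paths[v], start, self.chunk_len)
        return frames, torch.tensor(self.labels[v], dtype=torch.float32)


def dummy_load_frames(path, start, length):
    """Stand-in for a real decoder: frame k is filled with the value k."""
    idx = torch.arange(start, start + length, dtype=torch.float32)
    return idx.view(length, 1, 1, 1).expand(length, 3, 8, 8).clone()


if __name__ == "__main__":
    ds = ChunkDataset(["a.mp4", "b.mp4"], [1.0, 2.0], [38, 8], dummy_load_frames, chunk_len=4)
    x, y = next(iter(DataLoader(ds, batch_size=3)))
    print(len(ds), x.shape, y)
```

Here the default collate function is enough, since every item already has the same shape. A real `load_frames` could, for example, seek with `cv2.VideoCapture.set(cv2.CAP_PROP_POS_FRAMES, start)` and then read `length` frames, although frame-accurate seeking depends on the codec.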