Get batches from the collate_fn of a PyTorch DataLoader

I am reading videos and splitting them into batches. A whole video does not fit into one batch, so I split it using a collate function. The shape of one video is (128, 3, 224, 224), where 128 is the number of frames and 3 is the number of channels (RGB).
The output shape from the data loader after applying the collate function is torch.Size([4, 32, 3, 224, 224]), where index 0 is the batch dimension. If I pass this batch to the model, I get a memory error. Is there a way to send batches of 2 instead?
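To make the shape arithmetic concrete: 128 frames split into clips of 32 frames gives 128 / 32 = 4 clips, which is where the leading 4 comes from. A minimal illustration of just the reshape (not part of the repro below):

import numpy as np

video = np.zeros((128, 3, 224, 224), dtype=np.uint8)  # one full video
clips = video.reshape(4, 32, 3, 224, 224)             # 128 frames -> 4 clips of 32 frames
print(clips.shape)                                    # (4, 32, 3, 224, 224)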
Here is the code for reproducibility purposes.
Generate dummy data

import numpy as np

features = []
labels = []
# The data sits in RAM here, but in reality it is on disk and
# cannot all be loaded into RAM at once to manipulate.
for i in range(5):
    features.append(np.random.randint(low=0, high=256, size=(128, 3, 224, 224)))
    labels.append(np.random.randint(low=0, high=2))

dataset

from torch.utils.data import Dataset

class DataReader(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __getitem__(self, index):
        x = self.features[index]
        y = self.labels[index]
        return x, y

    def __len__(self):
        return len(self.features)

dataloader

from torch.utils.data import DataLoader

def reshaped(i, bs=32):
    # drop any leftover frames, then split the 128 frames into chunks of 32:
    # (128, 3, 224, 224) -> (4, 32, 3, 224, 224)
    i = i[len(i) % bs:].reshape(len(i) // bs, -1, 3, 224, 224)
    return i

def collate_fn(batch):
    # batch is a single (features, label) pair because batch_size=None
    data = reshaped(batch[0])
    print(data.shape)
    label = [batch[1]] * len(data)
    return data, label

dl = DataLoader(DataReader(features, labels), batch_size=None, num_workers=0,
                pin_memory=True, shuffle=False, collate_fn=collate_fn)
batch=next(iter(dl))
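
As an aside, pin_memory=True only pins torch tensors, and the collate_fn above returns a NumPy array; a sketch of a variant that converts explicitly (collate_fn_tensor is just an illustrative name, not part of the original code) would be:

import torch

def collate_fn_tensor(batch):
    # convert the NumPy clips to a tensor so pin_memory can take effect
    data = torch.from_numpy(reshaped(batch[0])).float()
    label = torch.tensor([batch[1]] * len(data))
    return data, label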

batch[0] has shape (4, 32, 3, 224, 224). Even a single batch of this generated data is large, so I want to send one (2, 32, 3, 224, 224) chunk and then another (2, 32, 3, 224, 224) chunk.

Don't be confused by batch_size=None; I am just disabling automatic batching. In practice, batch_size=None loads one sample (one video), which I then convert into four batches inside the collate function.
But the issue is that all four batches are now passed on together, and I want only 2 of them sent to the model at a time, then the next 2, instead of all four at once.
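
For illustration, slicing the collated batch along dimension 0 before the forward pass would give this behaviour (this works on both the NumPy array and a tensor); model below is a hypothetical placeholder, not defined in this post:

sub_bs = 2  # desired sub-batch size

for data, label in dl:
    # data has shape (4, 32, 3, 224, 224); walk it in chunks of sub_bs
    for start in range(0, len(data), sub_bs):
        chunk = data[start:start + sub_bs]        # (2, 32, 3, 224, 224)
        chunk_labels = label[start:start + sub_bs]
        output = model(chunk)                     # `model` is a placeholder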

Cross-posted from here.