Get batches from collate_fn of a PyTorch DataLoader

I am reading videos and splitting them into batches. I can't fit a whole video in one batch, so I split each video using a collate function. The shape of a video is (128, 3, 224, 224),
where 128 is the number of frames and 3 is the number of channels (RGB).
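As a rough illustration of that split: 128 frames divided into clips of 32 frames gives 128 / 32 = 4 clips, so one video becomes a (4, 32, 3, 224, 224) tensor. The sketch below assumes the video is already a torch tensor; `split_video` and `clip_len` are hypothetical names, and leftover frames (when the frame count is not a multiple of `clip_len`) are simply dropped, which is an assumption rather than something stated in the question.

```python
import torch


def split_video(video: torch.Tensor, clip_len: int = 32) -> torch.Tensor:
    """Reshape (frames, C, H, W) into (n_clips, clip_len, C, H, W).

    Frames that do not fill a whole clip are dropped (an assumption;
    padding the last clip would be the other common choice).
    """
    n_clips = video.shape[0] // clip_len
    trimmed = video[: n_clips * clip_len]
    # reshape returns a view when possible and copies only if it must
    return trimmed.reshape(n_clips, clip_len, *video.shape[1:])


video = torch.zeros(128, 3, 224, 224, dtype=torch.uint8)
clips = split_video(video)
print(clips.shape)  # torch.Size([4, 32, 3, 224, 224])
```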
The output shape from the data loader after applying the collate function is
torch.Size([4, 32, 3, 224, 224]), where index 0 is the batch dimension. If I pass this batch to the model, I get a memory error. Is there a way to send a batch of 2?
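One common way to get batches of 2 without changing the loader is to split the (4, 32, 3, 224, 224) tensor along dimension 0 just before the forward pass, so that only 2 clips reach the model at a time. This is a sketch, not the asker's code: `model`, `device` and `micro_batches` are placeholder names, one label per clip is assumed, and the tiny 8x8 frames only keep the example cheap. The full 4-clip tensor still sits in host memory; only what is sent to the model is halved.

```python
import torch


def micro_batches(data: torch.Tensor, label: torch.Tensor, size: int = 2):
    """Yield (data, label) pieces of at most `size` clips along dim 0."""
    # torch.split returns views, so no extra copy is made here
    for d, l in zip(torch.split(data, size), torch.split(label, size)):
        yield d, l


# Stand-in for one batch from the DataLoader (small H, W for the example)
data = torch.zeros(4, 32, 3, 8, 8)
label = torch.tensor([1, 1, 1, 1])  # assumed: one label per clip
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Identity()  # hypothetical placeholder for the real model

shapes = []
for d, l in micro_batches(data, label):
    out = model(d.to(device))
    shapes.append(tuple(out.shape))
print(shapes)  # [(2, 32, 3, 8, 8), (2, 32, 3, 8, 8)]
```

For training, calling `loss.backward()` on each piece and `optimizer.step()` once after both accumulates gradients over all 4 clips, because `backward()` adds into `.grad`; if the loss is a mean, dividing each piece's loss by the number of pieces keeps the scale equal to a single batch of 4.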
Here is the code to reproduce the issue.
Generate dummy data:

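import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

features, labels = [], []
# here the data is in RAM, but actually it's on disk and I can't load all of it into RAM
for i in range(5):
    features.append(np.random.randint(low=0, high=256, size=(128, 3, 224, 224)))
    labels.append(np.random.randint(low=0, high=2))


class DataReader(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __getitem__(self, index):
        x = torch.from_numpy(self.features[index])  # (128, 3, 224, 224)
        y = self.labels[index]
        return x, y

    def __len__(self):
        return len(self.features)


def reshaped(i, bs=32):
    # (128, 3, 224, 224) -> (128 // bs, bs, 3, 224, 224)
    return i.view(-1, bs, *i.shape[1:])


def collate_fn(batch):
    # with batch_size=None, collate_fn receives a single (x, y) sample
    x, y = batch
    data = reshaped(x)  # (4, 32, 3, 224, 224)
    label = torch.full((data.shape[0],), y)  # assumed: video label repeated per clip
    return data, label


dl = DataLoader(DataReader(features, labels), batch_size=None, num_workers=0,
                pin_memory=True, shuffle=False, collate_fn=collate_fn)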
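An alternative that avoids the custom collate_fn is to make each dataset item a single 32-frame clip and let the DataLoader's own `batch_size=2` do the grouping, so each batch is (2, 32, 3, H, W). This sketch assumes every video has the same number of frames (128 here) and that slicing `features[v]` is cheap; for data stored as `.npy` files, opening each video with `np.load(path, mmap_mode='r')` would read only the requested frames from disk (also an assumption about the storage format). `ClipDataset` is a hypothetical name.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ClipDataset(Dataset):
    """Each item is one clip of `clip_len` frames from one video."""

    def __init__(self, features, labels, clip_len=32, n_frames=128):
        self.features = features
        self.labels = labels
        self.clip_len = clip_len
        self.clips_per_video = n_frames // clip_len

    def __len__(self):
        return len(self.features) * self.clips_per_video

    def __getitem__(self, index):
        # map the flat index to (video, clip within that video)
        v, c = divmod(index, self.clips_per_video)
        start = c * self.clip_len
        clip = np.asarray(self.features[v][start:start + self.clip_len])
        return torch.from_numpy(clip), self.labels[v]


# Small dummy data: 5 videos of 128 frames at 8x8 resolution
features = [np.random.randint(0, 256, size=(128, 3, 8, 8)) for _ in range(5)]
labels = [np.random.randint(0, 2) for _ in range(5)]

dl = DataLoader(ClipDataset(features, labels), batch_size=2, shuffle=False)
data, label = next(iter(dl))
print(data.shape, label.shape)  # torch.Size([2, 32, 3, 8, 8]) torch.Size([2])
```

With `shuffle=False`, consecutive batches take consecutive clips of the same video; with `shuffle=True` clips from different videos would be mixed, which may or may not be wanted.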

batch[0] has shape (4, 32, 3, 224, 224). Even a single batch of the generated data is large, so I want to send (2, 32, 3, 224, 224) and then another (2, 32, 3, 224, 224).

Don't be confused by batch_size=None; I am just disabling automatic batching. In practice, batch_size=None loads one video per iteration, and I convert that one video into four batches in the collate function.
The issue is that all four are then sent on together. I want only two of them to be sent to the model, then the next two, instead of all four at once.
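PyTorch's data loading documentation states that when automatic batching is disabled, collate_fn is applied to each individual sample and whatever it returns is yielded as one item, which is why the four clips arrive together. To keep loading one video at a time but have the DataLoader itself emit (2, 32, 3, H, W) pieces, an IterableDataset can load a video and yield its clips two at a time; with `batch_size=None` each `yield` becomes one item. This is a sketch under assumptions: `VideoChunks` and `load_video` are hypothetical names, `load_video` stands in for however a video is read from disk, and with `num_workers > 0` you would need to shard videos across workers (for example with `torch.utils.data.get_worker_info()`) to avoid duplicates.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class VideoChunks(IterableDataset):
    """Yield (clips, labels) items with at most `chunk` clips each."""

    def __init__(self, load_video, labels, clip_len=32, chunk=2):
        self.load_video = load_video  # callable: index -> (frames, C, H, W) tensor
        self.labels = labels
        self.clip_len = clip_len
        self.chunk = chunk

    def __iter__(self):
        for i, y in enumerate(self.labels):
            video = self.load_video(i)  # only one video in memory at a time
            n = video.shape[0] // self.clip_len
            clips = video[: n * self.clip_len].reshape(n, self.clip_len, *video.shape[1:])
            for piece in torch.split(clips, self.chunk):
                yield piece, torch.full((piece.shape[0],), y)


# Hypothetical loader: small random "videos" instead of files on disk
def load_video(i):
    return torch.randint(0, 256, (128, 3, 8, 8))


labels = [0, 1, 1, 0, 1]
dl = DataLoader(VideoChunks(load_video, labels), batch_size=None)
shapes = [tuple(d.shape) for d, _ in dl]
print(len(shapes), shapes[0])  # 10 (2, 32, 3, 8, 8)
```

Each 128-frame video gives 4 clips and therefore 2 items, so the 5 dummy videos produce 10 items of 2 clips each.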

Double post from here.