How do I write a Dataset to extract patches from 3D images?

I have a question about extracting patches from 3D images.
Say I have many 3D images with different resolutions, and from each one I want to extract a different number of fixed-size patches (40×40×40). How should I write `__getitem__` so that the dataloader extracts patches one image at a time and I get batches of shape (batch_size, channel, 40, 40, 40) when enumerating the dataloader?
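
To make the question concrete, here is a minimal sketch of one possible approach, not a definitive implementation. It assumes the volumes are already in memory as tensors of shape (channel, D, H, W), that patches are taken on a non-overlapping grid, and that any border region smaller than a patch is dropped. The class name `PatchDataset` and the `patch_size` and `stride` parameters are made up for illustration. The idea is to build a flat index of (image, corner) pairs up front, so that `__len__` is the total number of patches and `__getitem__` returns one patch.

```python
import itertools

import torch
from torch.utils.data import Dataset, DataLoader


class PatchDataset(Dataset):
    """Hypothetical dataset that yields fixed-size 3D patches from volumes."""

    def __init__(self, volumes, patch_size=40, stride=40):
        # Assumption: each volume is a tensor of shape (C, D, H, W).
        self.volumes = volumes
        self.patch_size = patch_size
        # Flat list of (volume index, z, y, x) patch corners, ordered by volume.
        self.index = []
        for vol_idx, vol in enumerate(volumes):
            _, d, h, w = vol.shape
            starts = [range(0, dim - patch_size + 1, stride) for dim in (d, h, w)]
            for z, y, x in itertools.product(*starts):
                self.index.append((vol_idx, z, y, x))

    def __len__(self):
        # Total number of patches across all volumes.
        return len(self.index)

    def __getitem__(self, i):
        vol_idx, z, y, x = self.index[i]
        p = self.patch_size
        # Slice one (C, p, p, p) patch out of the selected volume.
        return self.volumes[vol_idx][:, z:z + p, y:y + p, x:x + p]


# Two toy volumes with different sizes, so they yield different patch counts.
volumes = [torch.randn(1, 80, 80, 40), torch.randn(1, 120, 40, 40)]
dataset = PatchDataset(volumes)

# shuffle=False keeps the index order, so patches come one image after another.
loader = DataLoader(dataset, batch_size=4, shuffle=False)
batch = next(iter(loader))
print(len(dataset), tuple(batch.shape))
```

With `shuffle=False` a batch can still span two images when the patch count of an image is not a multiple of the batch size. For volumes too large to keep in memory, the same index could store file paths instead, with the loading done inside `__getitem__`.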