How would you like to create these patches?
If you would like to create them in a non-overlapping way, you would end up with (172//32) * (220//32) * (156//32) = 5 * 6 * 4 = 120 patches.
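The count above comes from floor-dividing each dimension by the patch size. Here is a small sketch of the general formula. The helper name `num_patches` is hypothetical and not a PyTorch function. It also covers a stride smaller than the patch size, which would produce overlapping patches:

```python
def num_patches(shape, patch_size, stride):
    """Total number of windows that unfold(dim, patch_size, stride) yields over all dims."""
    total = 1
    for size in shape:
        # windows per dimension: floor((size - patch_size) / stride) + 1
        total *= (size - patch_size) // stride + 1
    return total

print(num_patches((172, 220, 156), 32, 32))  # non-overlapping: 120
print(num_patches((172, 220, 156), 32, 16))  # 50% overlap: 864
```

With non-overlapping patches, the leftover voxels at the end of each dimension are discarded. For example, 172 - 5 * 32 = 12 voxels are dropped along the first dimension.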
Could you explain how these 500 patches should be created?
For the non-overlapping case, you could load the image, use unfold to create the patches, and return them in your __getitem__:
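import torch
x = torch.randn(172, 220, 156)
# unfold(dim, size, step) with step == size creates non-overlapping windows;
# unfolding dims 0, 1, 2 in this order keeps each patch's axes in the original order
patches = x.unfold(0, 32, 32).unfold(1, 32, 32).unfold(2, 32, 32)  # [5, 6, 4, 32, 32, 32]
# flatten the 5 x 6 x 4 patch grid into a single patch dimension
patches = patches.contiguous().view(-1, 32, 32, 32)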
print(patches.shape)
> torch.Size([120, 32, 32, 32])
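Returning these patches from __getitem__ could look roughly like the following sketch. It rests on a few assumptions. The class name VolumePatchDataset is hypothetical. The volumes are random tensors standing in for whatever image loading you actually use. The dims are unfolded in the order 0, 1, 2, so patch 0 equals x[:32, :32, :32].

```python
import torch
from torch.utils.data import Dataset, DataLoader

class VolumePatchDataset(Dataset):
    """Hypothetical dataset: each item is one volume split into non-overlapping cubes."""

    def __init__(self, volumes, patch_size=32):
        # `volumes` stands in for your real loading logic (e.g. reading files lazily)
        self.volumes = volumes
        self.patch_size = patch_size

    def __len__(self):
        return len(self.volumes)

    def __getitem__(self, index):
        x = self.volumes[index]
        p = self.patch_size
        # step == size, so windows along each dim do not overlap
        patches = x.unfold(0, p, p).unfold(1, p, p).unfold(2, p, p)
        # flatten the (n0, n1, n2) patch grid into one patch dimension
        return patches.contiguous().view(-1, p, p, p)

volumes = [torch.randn(172, 220, 156) for _ in range(2)]
loader = DataLoader(VolumePatchDataset(volumes), batch_size=1)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([1, 120, 32, 32, 32])
```

Each item returns all patches of one volume. Volumes with different shapes would therefore yield different patch counts, and the default collate function would fail for batch_size > 1. In that case you could keep batch_size=1 or index individual patches instead.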
The code you’ve provided for the 2D case doesn’t seem to create patches, or am I missing something?
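For comparison, here is a minimal 2D version of the same unfold approach. The 100 x 64 single-channel image and the 16 x 16 patch size are arbitrary example values:

```python
import torch

img = torch.randn(100, 64)  # hypothetical H x W image
p = 16
# non-overlapping 16x16 windows along height, then width
patches = img.unfold(0, p, p).unfold(1, p, p)  # shape: (6, 4, 16, 16)
patches = patches.contiguous().view(-1, p, p)
print(patches.shape)  # torch.Size([24, 16, 16])
```

For batched (N, C, H, W) inputs, torch.nn.functional.unfold is a related im2col-style alternative. It returns a tensor of shape (N, C * p * p, L), where L is the number of patches.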