How to extract smaller image patches (3D)?

For the second use case, you could use Tensor.unfold to split the volume into patches:

import torch

S = 128  # channel dim
W = 256  # width
H = 256  # height
batch_size = 10

x = torch.randn(batch_size, S, W, H)

size = 64    # patch size
stride = 64  # patch stride (equal to size, so the patches do not overlap)
patches = x.unfold(1, size, stride).unfold(2, size, stride).unfold(3, size, stride)
print(patches.shape)
# torch.Size([10, 2, 4, 4, 64, 64, 64])

patches now contains, for each of the 10 samples, a [2, 4, 4] grid of patches (along S, W and H), each of size [64, 64, 64].
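As a quick sanity check, here is a minimal sketch on a smaller tensor (the sizes are assumptions chosen so it runs fast). It shows that patches[b, i, j, k] is exactly the block of x starting at (i*stride, j*stride, k*stride), how the patch grid could be flattened into a single batch of patches, and what happens when stride is smaller than size: the patches overlap, and each dimension yields (length - size) // stride + 1 patches, with any leftover elements dropped.

```python
import torch

# Small stand-in for the volume above (assumed sizes for a fast check)
batch_size, S, W, H = 2, 8, 16, 16
size, stride = 4, 4  # non-overlapping 4x4x4 patches

x = torch.randn(batch_size, S, W, H)
patches = x.unfold(1, size, stride).unfold(2, size, stride).unfold(3, size, stride)
print(patches.shape)  # torch.Size([2, 2, 4, 4, 4, 4, 4])

# patches[b, i, j, k] is the block starting at (i*stride, j*stride, k*stride)
b, i, j, k = 1, 1, 2, 3
block = x[b,
          i * stride:i * stride + size,
          j * stride:j * stride + size,
          k * stride:k * stride + size]
print(torch.equal(patches[b, i, j, k], block))  # True

# Flatten the patch grid into one batch of patches (reshape copies if needed,
# since unfold returns a non-contiguous view)
flat = patches.reshape(-1, size, size, size)
print(flat.shape)  # torch.Size([64, 4, 4, 4])

# stride < size gives overlapping patches: (length - size) // stride + 1 per dim
overlap_stride = 2
overlapping = (x.unfold(1, size, overlap_stride)
                .unfold(2, size, overlap_stride)
                .unfold(3, size, overlap_stride))
print(overlapping.shape)  # torch.Size([2, 3, 7, 7, 4, 4, 4])
```

The same indexing rule applies to the full-size example, where the grid is [2, 4, 4] and each patch is [64, 64, 64].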

For the first use case, you could use pooling operators such as nn.MaxPool3d or nn.AvgPool3d.
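A minimal sketch of the module versions, assuming a small [batch, S, W, H] tensor laid out like the one above. The 3D pooling layers expect input of shape (N, C, D, H, W), so a singleton channel dimension is added with unsqueeze(1) and removed again afterwards. The last lines check that, with kernel_size=2 and stride=2, the layers take the max and the mean over each non-overlapping 2x2x2 block.

```python
import torch
import torch.nn as nn

# Assumed small volume: [batch, S, W, H]
x = torch.randn(2, 8, 16, 16)
x5d = x.unsqueeze(1)  # (2, 1, 8, 16, 16): singleton channel dim for 3D pooling

max_pool = nn.MaxPool3d(kernel_size=2, stride=2)
avg_pool = nn.AvgPool3d(kernel_size=2, stride=2)

x_max = max_pool(x5d).squeeze(1)  # (2, 4, 8, 8)
x_avg = avg_pool(x5d).squeeze(1)  # (2, 4, 8, 8)

# Split every spatial dim into (blocks, 2) and reduce over the 2x2x2 block dims
blocks = x.reshape(2, 4, 2, 8, 2, 8, 2)
manual_max = blocks.amax(dim=(2, 4, 6))
manual_avg = blocks.mean(dim=(2, 4, 6))

print(torch.equal(x_max, manual_max))     # True
print(torch.allclose(x_avg, manual_avg))  # True
```

MaxPool3d keeps the strongest response in each block, while AvgPool3d smooths it; which one fits depends on whether you want to preserve peaks or overall intensity.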

Here is an example using the functional API:

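import torch.nn.functional as F
# 3D pooling expects (N, C, D, H, W): add a channel dim, pool, then drop it again
x_half = F.max_pool3d(x.unsqueeze(1), kernel_size=2, stride=2).squeeze(1)  # [10, 64, 128, 128]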