For the second use case you could use Tensor.unfold:
import torch

S = 128 # channel dim
W = 256 # width
H = 256 # height
batch_size = 10
x = torch.randn(batch_size, S, W, H)
size = 64 # patch size
stride = 64 # patch stride
# unfold dims 1, 2, and 3 into non-overlapping windows of length `size`
patches = x.unfold(1, size, stride).unfold(2, size, stride).unfold(3, size, stride)
print(patches.shape)
> torch.Size([10, 2, 4, 4, 64, 64, 64])
patches now contains [2, 4, 4] patches of size [64, 64, 64].
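If you want to process each patch independently, you could flatten the [2, 4, 4] patch grid into a single dimension; here is a minimal sketch (the patches_flat name is just for illustration):

# make the strided view contiguous, then merge the three patch-grid dims
patches_flat = patches.contiguous().view(batch_size, -1, size, size, size)
print(patches_flat.shape)
> torch.Size([10, 32, 64, 64, 64])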
For the first use case you could use pooling operators like nn.MaxPool3d and nn.AvgPool3d.
Here is an example using the functional API:
import torch.nn.functional as F

# pass the tensor directly; Variable is deprecated since PyTorch 0.4
x_half = F.max_pool3d(x, kernel_size=2, stride=2)  # -> [10, 64, 128, 128]
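The module API mentioned above works the same way; a minimal sketch using nn.AvgPool3d instead (the pool and x_half_avg names are just for illustration):

import torch.nn as nn

# halves S, W, and H as well, but averages instead of taking the max
pool = nn.AvgPool3d(kernel_size=2, stride=2)
x_half_avg = pool(x)
print(x_half_avg.shape)
> torch.Size([10, 64, 128, 128])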