Efficiently slicing tensor like a convolution?

Calling `Tensor.unfold` on the height and width dimensions should yield the desired output:

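```python
import torch

B, C, H, W = 2, 3, 4, 4
x = torch.arange(B*C*H*W).view(B, C, H, W)

kernel_h, kernel_w = 2, 2
stride = 2

# Unfold dim 2 (height), then dim 3 (width); each call appends a window dimension,
# giving shape (B, C, out_h, out_w, kernel_h, kernel_w)
patches = x.unfold(2, kernel_h, stride).unfold(3, kernel_w, stride)
print(patches)
```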
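To see why this behaves like a convolution's sliding window, here is a small sketch that checks the result two ways. `patches[b, c, i, j]` should be the `kernel_h x kernel_w` window whose top-left corner is at `(i*stride, j*stride)`, and after rearranging, the patches should match `torch.nn.functional.unfold` (the im2col operation). The `float32` dtype is an assumption added here so `F.unfold` accepts the input; the names `out_h`, `out_w`, `cols` and `cols_from_patches` are illustrative. As with an unpadded convolution, rows or columns not covered by a full window are dropped.

```python
import torch
import torch.nn.functional as F

B, C, H, W = 2, 3, 4, 4
# float32 is an assumption so that F.unfold accepts the tensor
x = torch.arange(B * C * H * W, dtype=torch.float32).view(B, C, H, W)

kernel_h, kernel_w = 2, 2
stride = 2

# Same chained unfold as above: (B, C, out_h, out_w, kernel_h, kernel_w)
patches = x.unfold(2, kernel_h, stride).unfold(3, kernel_w, stride)

# Number of windows per spatial dim, same formula as an unpadded convolution
out_h = (H - kernel_h) // stride + 1
out_w = (W - kernel_w) // stride + 1
print(patches.shape)  # torch.Size([2, 3, 2, 2, 2, 2])

# Each patch equals the plain slice starting at (i*stride, j*stride)
for i in range(out_h):
    for j in range(out_w):
        window = x[:, :, i * stride:i * stride + kernel_h, j * stride:j * stride + kernel_w]
        assert torch.equal(patches[:, :, i, j], window)

# F.unfold returns (B, C*kernel_h*kernel_w, out_h*out_w)
cols = F.unfold(x, kernel_size=(kernel_h, kernel_w), stride=stride)

# Move the window dims next to C, then flatten to the im2col layout
cols_from_patches = patches.permute(0, 1, 4, 5, 2, 3).reshape(
    B, C * kernel_h * kernel_w, out_h * out_w
)
assert torch.equal(cols, cols_from_patches)
```

The chained `Tensor.unfold` result is a view, so it avoids copying until you reshape it; `F.unfold` always materializes the flattened column matrix, which is convenient if you want to follow up with a matrix multiply against flattened kernel weights.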