tensor.unfold should yield the desired output:
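import torch

B, C, H, W = 2, 3, 4, 4
x = torch.arange(B * C * H * W).view(B, C, H, W)
kernel_h, kernel_w = 2, 2
stride = 2
# Unfold along height (dim 2), then width (dim 3); each call appends a window dimension
patches = x.unfold(2, kernel_h, stride).unfold(3, kernel_w, stride)
print(patches)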
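
To check what the snippet above actually produces, here is a small sketch that inspects the result's shape and one patch. The final reshape into a single patch dimension, and the name `flat`, are my own additions for illustration and may not be the layout you need:

```python
import torch

B, C, H, W = 2, 3, 4, 4
x = torch.arange(B * C * H * W).view(B, C, H, W)
kernel_h, kernel_w = 2, 2
stride = 2

patches = x.unfold(2, kernel_h, stride).unfold(3, kernel_w, stride)

# Layout: [B, C, patches_along_H, patches_along_W, kernel_h, kernel_w]
print(patches.shape)  # torch.Size([2, 3, 2, 2, 2, 2])

# Top-left patch of the first sample, first channel
print(patches[0, 0, 0, 0])  # tensor([[0, 1], [4, 5]])

# Assumption: merge the two patch-grid dims into one patch index.
# reshape is used instead of view because the unfolded tensor is not contiguous.
flat = patches.reshape(B, C, -1, kernel_h, kernel_w)
print(flat.shape)  # torch.Size([2, 3, 4, 2, 2])
```

With stride equal to the kernel size the patches do not overlap; a smaller stride would give overlapping windows and a larger patch grid.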