The error most likely occurs because the input dimensions cannot be evenly
divided into 32-sized patches.
Try padding the input before unfolding it:
# Split a (N, C, H, W) tensor into non-overlapping 32x32x32 patches over its
# last three dimensions, then fold the patches back and verify the round trip.
x = torch.randn(1, 172, 220, 156)
kc, kh, kw = 32, 32, 32  # kernel size
dc, dh, dw = 32, 32, 32  # stride (equal to the kernel size: no overlap)

# Pad dims 1-3 up to the next multiple of the kernel size. Without this,
# unfold silently drops the trailing remainder of each dimension.
# `(k - size % k) % k` is 0 when the size is already a multiple of k.
pad_c = (kc - x.size(1) % kc) % kc
pad_h = (kh - x.size(2) % kh) % kh
pad_w = (kw - x.size(3) % kw) % kw

# F.pad expects (before, after) pairs starting from the LAST dimension.
# The padding is split as evenly as possible; any odd element goes "after".
x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2,
              pad_h // 2, pad_h - pad_h // 2,
              pad_c // 2, pad_c - pad_c // 2))

# Shape after unfolding: (N, nC, nH, nW, kc, kh, kw)
patches = x.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
unfold_shape = patches.size()
# Flatten to a list of patches: (N * nC * nH * nW, kc, kh, kw)
patches = patches.contiguous().view(-1, kc, kh, kw)

# Reshape back
patches_orig = patches.view(unfold_shape)
output_c = unfold_shape[1] * unfold_shape[4]
output_h = unfold_shape[2] * unfold_shape[5]
output_w = unfold_shape[3] * unfold_shape[6]
# Interleave each patch-count axis with its matching in-patch axis so that
# the following view merges them back into full C, H and W dimensions.
patches_orig = patches_orig.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
patches_orig = patches_orig.view(unfold_shape[0], output_c, output_h, output_w)

# Check for equality (the slice now covers the whole padded tensor)
print((patches_orig == x[:, :output_c, :output_h, :output_w]).all())
> tensor(1, dtype=torch.uint8)