The padding used in my example is too naive; you might want to use e.g. `divmod`
to calculate the padding size:
def unfold_tensor(x, step_c, step_h, step_w):
    """Split a 4-D tensor into non-overlapping (step_c, step_h, step_w) patches.

    Dimensions 1, 2 and 3 of ``x`` are zero-padded at their end so that each
    is a multiple of its step; then the tensor is unfolded with
    kernel size == stride, so patches do not overlap.

    Args:
        x: 4-D tensor; presumably laid out as (batch, channels, height, width)
            — TODO confirm with callers.
        step_c: patch size (and stride) along dimension 1.
        step_h: patch size (and stride) along dimension 2.
        step_w: patch size (and stride) along dimension 3.

    Returns:
        A tuple ``(patches, unfold_shape)``. ``patches`` has shape
        ``(-1, nc * nh * nw, kc * kh * kw)``, and ``unfold_shape`` is the
        7-D size ``(B, nc, nh, nw, kc, kh, kw)`` of the unfolded tensor
        before flattening. Pass it to ``fold_tensor`` to undo the split.
    """
    kc, kh, kw = step_c, step_h, step_w  # kernel size
    dc, dh, dw = step_c, step_h, step_w  # stride (== kernel size: no overlap)

    def _pad_amount(size, k):
        # Smallest amount that rounds ``size`` up to a multiple of ``k``.
        n, remainder = divmod(size, k)
        return (n + bool(remainder)) * k - size

    pad_c = _pad_amount(x.size(1), kc)
    pad_h = _pad_amount(x.size(2), kh)
    pad_w = _pad_amount(x.size(3), kw)

    # F.pad takes (before, after) pairs starting from the LAST dimension:
    # (w_before, w_after, h_before, h_after, c_before, c_after).
    # Padding is applied only at the end of each dimension.
    x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_c))

    patches = x.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
    unfold_shape = patches.size()
    patches = patches.reshape(
        -1,
        unfold_shape[1] * unfold_shape[2] * unfold_shape[3],
        unfold_shape[4] * unfold_shape[5] * unfold_shape[6],
    )
    return patches, unfold_shape
def fold_tensor(x, shape_x, shape_orginal):
    """Reassemble flattened non-overlapping patches into a 4-D tensor.

    Inverse of the unfold step: ``x`` is reshaped to the 7-D unfolded size
    ``shape_x == (B, nc, nh, nw, kc, kh, kw)``. The patch-grid and in-patch
    axes are then interleaved and merged back into three dimensions.

    Args:
        x: flattened patches whose element count equals ``prod(shape_x)``.
        shape_x: 7-D size of the unfolded tensor before flattening.
        shape_orginal: not used here; kept for interface compatibility. The
            result still contains any end-padding added before unfolding, so
            crop it with the original size if needed.

    Returns:
        Tensor of shape ``(B, nc * kc, nh * kh, nw * kw)``.
    """
    x = x.reshape(
        -1, shape_x[1], shape_x[2], shape_x[3], shape_x[4], shape_x[5], shape_x[6]
    )
    # (B, nc, nh, nw, kc, kh, kw) -> (B, nc, kc, nh, kh, nw, kw) so that each
    # grid index sits next to its in-patch index before merging.
    x = x.permute(0, 1, 4, 2, 5, 3, 6).contiguous()

    # Fold
    output_c = shape_x[1] * shape_x[4]
    output_h = shape_x[2] * shape_x[5]
    output_w = shape_x[3] * shape_x[6]
    # Keep the batch size inferred above instead of hard-coding 1.
    x = x.view(x.size(0), output_c, output_h, output_w)
    return x
# Round-trip demo: unfold into patches, fold back, crop the padding, compare.
b, c, h, w = 1, 8, 28, 28
x = torch.randint(10, (b, c, h, w))
print(x.shape)
shape_original = x.size()

# One patch spanning all channels, 3x3 spatially.
patches, shape_patches = unfold_tensor(x, c, 3, 3)
print(patches.shape)

fold_patches = fold_tensor(patches, shape_patches, shape_original)
print(fold_patches.shape)

# Remove the zero padding that unfold_tensor appended to each dimension.
fold_patches = fold_patches[
    :, : shape_original[1], : shape_original[2], : shape_original[3]
]
print(fold_patches.shape)
print((x == fold_patches).all())
> tensor(True)
Currently the padding is applied to only one side (the end of each dimension); you could split it so that both sides are padded.