Unfold a tensor

The padding used in my earlier example was too naive; you might want to use e.g. `divmod` to calculate the padding size instead:

def unfold_tensor(x, step_c, step_h, step_w):
    """Split a 4-D tensor into non-overlapping (step_c, step_h, step_w) patches.

    Dims 1, 2 and 3 of ``x`` are each zero-padded at their end up to the next
    multiple of the corresponding step, so no element is dropped. They are then
    cut into windows whose stride equals their size. Dim 0 is kept as a leading
    (batch) dimension.

    Args:
        x: 4-D tensor; dims 1, 2, 3 are the ones split into patches.
        step_c, step_h, step_w: patch size (and stride) along dims 1, 2, 3.

    Returns:
        ``(patches, unfold_shape)``. ``patches`` has shape
        ``(B, nc*nh*nw, step_c*step_h*step_w)``. ``unfold_shape`` is the 7-D
        ``torch.Size`` ``(B, nc, nh, nw, step_c, step_h, step_w)``, which is
        needed to reassemble the patches.
    """
    kc, kh, kw = step_c, step_h, step_w  # kernel size
    dc, dh, dw = step_c, step_h, step_w  # stride == kernel -> no overlap

    # Patches per dimension, rounded up (ceil division) so the tail is kept.
    nc, remainder = divmod(x.size(1), kc)
    nc += bool(remainder)

    nh, remainder = divmod(x.size(2), kh)
    nh += bool(remainder)

    nw, remainder = divmod(x.size(3), kw)
    nw += bool(remainder)

    pad_c = nc * kc - x.size(1)
    pad_h = nh * kh - x.size(2)
    pad_w = nw * kw - x.size(3)
    # F.pad lists pads starting from the LAST dimension:
    # (w_before, w_after, h_before, h_after, c_before, c_after).
    # Padding is applied only at the end of each dimension.
    x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_c))
    patches = x.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
    unfold_shape = patches.size()
    patches = patches.reshape(
        -1,
        unfold_shape[1] * unfold_shape[2] * unfold_shape[3],
        unfold_shape[4] * unfold_shape[5] * unfold_shape[6],
    )
    return patches, unfold_shape

def fold_tensor(x, shape_x, shape_orginal):
    """Reassemble patches produced by ``unfold_tensor`` into a dense 4-D tensor.

    Args:
        x: patches of shape ``(B, nc*nh*nw, kc*kh*kw)``.
        shape_x: 7-D shape ``(B, nc, nh, nw, kc, kh, kw)`` describing how the
            patches were cut (the ``unfold_shape`` from ``unfold_tensor``).
        shape_orginal: not used here; kept so existing callers keep working.
            The result has the padded size, so slice it back to the original
            shape if padding was added.

    Returns:
        Tensor of shape ``(B, nc*kc, nh*kh, nw*kw)``.
    """
    nc, nh, nw, kc, kh, kw = (shape_x[i] for i in range(1, 7))
    # Batch size is inferred (-1) rather than hard-coded to 1, so B > 1 works.
    x = x.reshape(-1, nc, nh, nw, kc, kh, kw)
    # Interleave grid and in-patch axes: (B, nc, kc, nh, kh, nw, kw).
    x = x.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
    # Fold: merge each (grid, in-patch) pair back into a single axis.
    return x.view(-1, nc * kc, nh * kh, nw * kw)

b, c, h, w = 1, 8, 28, 28
x = torch.randint(10, (b, c, h, w))
print(x.shape)
shape_original = x.size()
# Take all channels at once and 3x3 spatial patches; 28 is not a multiple of 3,
# so unfold_tensor pads h and w up to 30.
patches, shape_patches = unfold_tensor(x, c, 3, 3)
print(patches.shape)
# Pass the correctly spelled variable (the original used an undefined name).
fold_patches = fold_tensor(patches, shape_patches, shape_original)
print(fold_patches.shape)
# Crop away the padding that unfold_tensor added at the end of each dim.
fold_patches = fold_patches[:, :shape_original[1], :shape_original[2], :shape_original[3]]
print(fold_patches.shape)
print((x == fold_patches).all())
> tensor(True)

Currently the padding is applied only to one side (the end) of each dimension; you could split it between the two sides so that both are padded.

1 Like