@ptrblck this is my example; however, it does not recover the original tensor when h and w are not divisible by k. This is my code:
import torch
from torch.nn import functional as F
def unfold_tensor(x, step_c, step_h, step_w):
    kc, kh, kw = step_c, step_h, step_w  # kernel size
    dc, dh, dw = step_c, step_h, step_w  # stride
    pad_c, pad_h, pad_w = x.size(1) % kc // 2, x.size(2) % kh // 2, x.size(3) % kw // 2
    x = F.pad(x, (pad_h, pad_h, pad_w, pad_w, pad_c, pad_c))
    patches = x.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
    unfold_shape = patches.size()
    patches = patches.reshape(-1, unfold_shape[1] * unfold_shape[2] * unfold_shape[3],
                              unfold_shape[4] * unfold_shape[5] * unfold_shape[6])
    return patches, unfold_shape
def fold_tensor(x, shape_x, shape_original):
    x = x.reshape(-1, shape_x[1], shape_x[2], shape_x[3], shape_x[4], shape_x[5], shape_x[6])
    x = x.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
    # Fold
    output_c = shape_x[1] * shape_x[4]
    output_h = shape_x[2] * shape_x[5]
    output_w = shape_x[3] * shape_x[6]
    x = x.view(1, output_c, output_h, output_w)
    return x
b, c, h, w = 1, 8, 28, 28
x = torch.randint(10, (b, c, h, w))
print(x.shape)
shape_original = x.size()
patches, shape_patches = unfold_tensor(x, c, 3, 3)
print(patches.shape)
fold_patches = fold_tensor(patches, shape_patches, shape_original)
print(fold_patches.shape)
print((x == fold_patches).all())
Output:
torch.Size([1, 8, 28, 28])
torch.Size([1, 81, 72])
torch.Size([1, 8, 27, 27])
Traceback (most recent call last):
  File "test_unfold.py", line 32, in <module>
    print((x == fold_patches).all())
RuntimeError: The size of tensor a (28) must match the size of tensor b (27) at non-singleton dimension 3
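From the shapes, the last row and column are simply dropped: with h = w = 28 and kh = kw = 3, the padding expression x.size(2) % kh // 2 evaluates to 1 // 2 = 0, so nothing is padded and unfold skips the leftover pixels (I also notice F.pad pads the last dimension first, so the (pad_h, pad_h, pad_w, pad_w, ...) ordering actually pads w before h). My understanding is that the padding should round each dimension up to the next multiple of the kernel size, and the fold should crop back to the original shape afterwards. Here is only a sketch of that idea, not a verified solution (the names unfold_tensor_padded and fold_tensor_cropped and the pad-right/crop-after strategy are my own):

import torch
from torch.nn import functional as F

def unfold_tensor_padded(x, kc, kh, kw):
    # Pad each dimension up to the next multiple of the kernel size.
    # F.pad pads the last dimension first: (w_left, w_right, h_top, h_bottom, c_front, c_back).
    pad_c = (kc - x.size(1) % kc) % kc
    pad_h = (kh - x.size(2) % kh) % kh
    pad_w = (kw - x.size(3) % kw) % kw
    x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_c))
    patches = x.unfold(1, kc, kc).unfold(2, kh, kh).unfold(3, kw, kw)
    unfold_shape = patches.size()
    patches = patches.reshape(-1, unfold_shape[1] * unfold_shape[2] * unfold_shape[3],
                              unfold_shape[4] * unfold_shape[5] * unfold_shape[6])
    return patches, unfold_shape

def fold_tensor_cropped(x, shape_x, shape_original):
    x = x.reshape(-1, shape_x[1], shape_x[2], shape_x[3], shape_x[4], shape_x[5], shape_x[6])
    x = x.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
    x = x.view(-1, shape_x[1] * shape_x[4], shape_x[2] * shape_x[5], shape_x[3] * shape_x[6])
    # Crop the right/bottom padding away to recover the original shape.
    return x[:, :shape_original[1], :shape_original[2], :shape_original[3]]

b, c, h, w = 1, 8, 28, 28
x = torch.randint(10, (b, c, h, w))
patches, shape_patches = unfold_tensor_padded(x, c, 3, 3)
print(patches.shape)  # torch.Size([1, 100, 72]) since 28 is padded to 30
out = fold_tensor_cropped(patches, shape_patches, x.size())
print((x == out).all())  # should print tensor(True)

If that reasoning is right, padding 28 up to 30 gives 10 x 10 spatial patches (100 patches of 72 elements each), and cropping after the fold should make the comparison pass. Does this look correct, or is there a cleaner way to handle the non-divisible case?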