Hi,
I have created a 3D patch extractor and combiner using the torch.nn.functional API.
The code is below.
Note that extract_patches_3d and extract_patches_3ds produce the same output for the same kernel size, stride and integer padding; the latter is just shorter, but it has no dilation argument.
Also note that when combining overlapping patches, the overlapping elements are summed (not averaged).
import torch
def _triple(value):
    """Expand an int into a 3-tuple (d, h, w); pass sequences through as tuples."""
    if isinstance(value, int):
        return (value, value, value)
    return tuple(value)


def _get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
    """Return how many sliding windows fit along one axis (same formula as Conv/unfold)."""
    return (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1


def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
    """Extract 3D patches from a (B, C, D, H, W) tensor using ``Tensor.unfold``.

    Args:
        x: 5D tensor of shape (B, C, D, H, W).
        kernel_size: int or (kD, kH, kW) patch size.
        padding: int (same zero padding on every side), (pD, pH, pW)
            symmetric per-axis padding (same convention as
            ``extract_patches_3d``), or a 6-tuple passed unchanged to
            ``torch.nn.functional.pad`` (W_left, W_right, H_top, H_bottom,
            D_front, D_back).
        stride: int or (sD, sH, sW).

    Returns:
        Tensor of shape (B * d_out * h_out * w_out, C, kD, kH, kW). Patch
        ``((b * d_out + i) * h_out + j) * w_out + k`` is the window at block
        position (i, j, k) of batch item b, with all C channels of that window.
    """
    kernel_size = _triple(kernel_size)
    stride = _triple(stride)
    if isinstance(padding, int):
        padding = (padding,) * 6
    elif len(padding) == 3:
        # F.pad lists the last axis first, so reorder (d, h, w) -> (w, w, h, h, d, d).
        pad_d, pad_h, pad_w = padding
        padding = (pad_w, pad_w, pad_h, pad_h, pad_d, pad_d)
    channels = x.shape[1]
    x = torch.nn.functional.pad(x, tuple(padding))
    # (B, C, D, H, W) -> (B, C, d_out, h_out, w_out, kD, kH, kW)
    x = (
        x.unfold(2, kernel_size[0], stride[0])
        .unfold(3, kernel_size[1], stride[1])
        .unfold(4, kernel_size[2], stride[2])
    )
    # Move C next to the kernel axes so each patch keeps its own C channels;
    # flattening without this permute mixes neighbouring windows into "channels".
    x = x.permute(0, 2, 3, 4, 1, 5, 6, 7)
    # (B, d_out, h_out, w_out, C, kD, kH, kW) -> (B * d_out * h_out * w_out, C, kD, kH, kW)
    return x.reshape(-1, channels, *kernel_size)


def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
    """Extract 3D patches from a (B, C, D, H, W) tensor via two 2D ``unfold`` calls.

    Depth is unfolded first by viewing H*W as one axis with a (kD, 1) kernel;
    H and W are then unfolded together with a (kH, kW) kernel.

    Args:
        x: 5D tensor of shape (B, C, D, H, W).
        kernel_size: int or (kD, kH, kW).
        padding: int or (pD, pH, pW) symmetric zero padding per axis.
        stride: int or (sD, sH, sW).
        dilation: int or (dD, dH, dW).

    Returns:
        Tensor of shape (B * d_out * h_out * w_out, C, kD, kH, kW), in the same
        patch order as ``extract_patches_3ds``.
    """
    kernel_size = _triple(kernel_size)
    padding = _triple(padding)
    stride = _triple(stride)
    dilation = _triple(dilation)
    channels, d_dim_in, h_dim_in, w_dim_in = x.shape[1:5]
    d_dim_out = _get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_out = _get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_out = _get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])

    # (B, C, D, H, W) -> (B, C, D, H * W)
    x = x.reshape(-1, channels, d_dim_in, h_dim_in * w_dim_in)
    x = torch.nn.functional.unfold(
        x,
        kernel_size=(kernel_size[0], 1),
        padding=(padding[0], 0),
        stride=(stride[0], 1),
        dilation=(dilation[0], 1),
    )
    # (B, C * kD, d_out * H * W) -> (B, C * kD * d_out, H, W)
    x = x.reshape(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)
    x = torch.nn.functional.unfold(
        x,
        kernel_size=(kernel_size[1], kernel_size[2]),
        padding=(padding[1], padding[2]),
        stride=(stride[1], stride[2]),
        dilation=(dilation[1], dilation[2]),
    )
    # (B, C * kD * d_out * kH * kW, h_out * w_out)
    x = x.reshape(
        -1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out
    )
    # (B, C, kD, d_out, kH, kW, h_out, w_out) -> (B, d_out, h_out, w_out, C, kD, kH, kW)
    # C must sit next to the kernel axes so each output patch carries its own C channels.
    x = x.permute(0, 3, 6, 7, 1, 2, 4, 5)
    return x.reshape(-1, channels, *kernel_size)


def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    """Inverse of ``extract_patches_3d``: stitch patches back into a volume.

    Overlapping patch elements are summed (``torch.nn.functional.fold``
    semantics); padded border elements are discarded.

    Args:
        x: patches of shape (B * d_blocks * h_blocks * w_blocks, C, kD, kH, kW),
            in the order produced by ``extract_patches_3d`` / ``extract_patches_3ds``.
        kernel_size: int or (kD, kH, kW).
        output_shape: shape of the original input; only the last three
            entries (D, H, W) are read.
        padding, stride, dilation: same values used for extraction (int or 3-tuple).

    Returns:
        Tensor of shape (B, C, D, H, W).
    """
    kernel_size = _triple(kernel_size)
    padding = _triple(padding)
    stride = _triple(stride)
    dilation = _triple(dilation)
    channels = x.shape[1]
    d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    # Number of patch positions along each axis of the original volume.
    d_blocks = _get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    h_blocks = _get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    w_blocks = _get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])

    x = x.reshape(-1, d_blocks, h_blocks, w_blocks, channels, *kernel_size)
    # (B, d_blocks, h_blocks, w_blocks, C, kD, kH, kW) -> (B, C, kD, d_blocks, kH, kW, h_blocks, w_blocks)
    x = x.permute(0, 4, 5, 1, 6, 7, 2, 3)
    x = x.reshape(
        -1,
        channels * kernel_size[0] * d_blocks * kernel_size[1] * kernel_size[2],
        h_blocks * w_blocks,
    )
    x = torch.nn.functional.fold(
        x,
        output_size=(h_dim_out, w_dim_out),
        kernel_size=(kernel_size[1], kernel_size[2]),
        padding=(padding[1], padding[2]),
        stride=(stride[1], stride[2]),
        dilation=(dilation[1], dilation[2]),
    )
    # (B, C * kD * d_blocks, H, W) -> (B, C * kD, d_blocks * H * W)
    x = x.reshape(-1, channels * kernel_size[0], d_blocks * h_dim_out * w_dim_out)
    x = torch.nn.functional.fold(
        x,
        output_size=(d_dim_out, h_dim_out * w_dim_out),
        kernel_size=(kernel_size[0], 1),
        padding=(padding[0], 0),
        stride=(stride[0], 1),
        dilation=(dilation[0], 1),
    )
    # (B, C, D, H * W) -> (B, C, D, H, W)
    return x.reshape(-1, channels, d_dim_out, h_dim_out, w_dim_out)
if __name__ == "__main__":
    # Round-trip demo: extract 2x2x2 patches (stride 2, padding 1) from a
    # (B=2, C=2, D=2, H=4, W=4) volume, then stitch them back together.
    a = torch.arange(1, 129, dtype=torch.float).view(2, 2, 2, 4, 4)
    print(a.shape)
    print(a)
    b = extract_patches_3ds(a, 2, padding=1, stride=2)
    print(b.shape)
    print(b)
    # The two extractors should agree for the same kernel size, stride and int padding.
    print(torch.equal(b, extract_patches_3d(a, 2, padding=1, stride=2)))
    c = combine_patches_3d(b, 2, (2, 2, 2, 4, 4), padding=1, stride=2)
    print(c.shape)
    print(c)
    # Stride equals kernel size, so patches don't overlap and the summed
    # reconstruction reproduces the input exactly.
    print(torch.all(a == c))
Output:
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[ 13., 14., 15., 16.]],
[[ 17., 18., 19., 20.],
[ 21., 22., 23., 24.],
[ 25., 26., 27., 28.],
[ 29., 30., 31., 32.]]],
[[[ 33., 34., 35., 36.],
[ 37., 38., 39., 40.],
[ 41., 42., 43., 44.],
[ 45., 46., 47., 48.]],
[[ 49., 50., 51., 52.],
[ 53., 54., 55., 56.],
[ 57., 58., 59., 60.],
[ 61., 62., 63., 64.]]]],
[[[[ 65., 66., 67., 68.],
[ 69., 70., 71., 72.],
[ 73., 74., 75., 76.],
[ 77., 78., 79., 80.]],
[[ 81., 82., 83., 84.],
[ 85., 86., 87., 88.],
[ 89., 90., 91., 92.],
[ 93., 94., 95., 96.]]],
[[[ 97., 98., 99., 100.],
[101., 102., 103., 104.],
[105., 106., 107., 108.],
[109., 110., 111., 112.]],
[[113., 114., 115., 116.],
[117., 118., 119., 120.],
[121., 122., 123., 124.],
[125., 126., 127., 128.]]]]])
torch.Size([36, 2, 2, 2, 2])
tensor([[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 1.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 2., 3.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 4., 0.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 5.],
[ 0., 9.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 6., 7.],
[ 10., 11.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 8., 0.],
[ 12., 0.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 13.],
[ 0., 0.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 14., 15.],
[ 0., 0.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 16., 0.],
[ 0., 0.]]],
[[[ 0., 0.],
[ 0., 17.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 0., 0.],
[ 18., 19.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 0., 0.],
[ 20., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 0., 21.],
[ 0., 25.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 22., 23.],
[ 26., 27.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 24., 0.],
[ 28., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 0., 29.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 30., 31.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 32., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 33.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 34., 35.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 36., 0.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 37.],
[ 0., 41.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 38., 39.],
[ 42., 43.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 40., 0.],
[ 44., 0.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 45.],
[ 0., 0.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 46., 47.],
[ 0., 0.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 48., 0.],
[ 0., 0.]]],
[[[ 0., 0.],
[ 0., 49.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 0., 0.],
[ 50., 51.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 0., 0.],
[ 52., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 0., 53.],
[ 0., 57.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 54., 55.],
[ 58., 59.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 56., 0.],
[ 60., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 0., 61.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 62., 63.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 64., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 65.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 66., 67.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 68., 0.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 69.],
[ 0., 73.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 70., 71.],
[ 74., 75.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 72., 0.],
[ 76., 0.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 77.],
[ 0., 0.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 78., 79.],
[ 0., 0.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 80., 0.],
[ 0., 0.]]],
[[[ 0., 0.],
[ 0., 81.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 0., 0.],
[ 82., 83.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 0., 0.],
[ 84., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 0., 85.],
[ 0., 89.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 86., 87.],
[ 90., 91.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 88., 0.],
[ 92., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 0., 93.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 94., 95.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 96., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 97.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 98., 99.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[100., 0.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 101.],
[ 0., 105.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[102., 103.],
[106., 107.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[104., 0.],
[108., 0.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 109.],
[ 0., 0.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[110., 111.],
[ 0., 0.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[112., 0.],
[ 0., 0.]]],
[[[ 0., 0.],
[ 0., 113.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 0., 0.],
[114., 115.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 0., 0.],
[116., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[ 0., 117.],
[ 0., 121.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[118., 119.],
[122., 123.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[120., 0.],
[124., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 0., 125.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[126., 127.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[128., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]]])
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[ 13., 14., 15., 16.]],
[[ 17., 18., 19., 20.],
[ 21., 22., 23., 24.],
[ 25., 26., 27., 28.],
[ 29., 30., 31., 32.]]],
[[[ 33., 34., 35., 36.],
[ 37., 38., 39., 40.],
[ 41., 42., 43., 44.],
[ 45., 46., 47., 48.]],
[[ 49., 50., 51., 52.],
[ 53., 54., 55., 56.],
[ 57., 58., 59., 60.],
[ 61., 62., 63., 64.]]]],
[[[[ 65., 66., 67., 68.],
[ 69., 70., 71., 72.],
[ 73., 74., 75., 76.],
[ 77., 78., 79., 80.]],
[[ 81., 82., 83., 84.],
[ 85., 86., 87., 88.],
[ 89., 90., 91., 92.],
[ 93., 94., 95., 96.]]],
[[[ 97., 98., 99., 100.],
[101., 102., 103., 104.],
[105., 106., 107., 108.],
[109., 110., 111., 112.]],
[[113., 114., 115., 116.],
[117., 118., 119., 120.],
[121., 122., 123., 124.],
[125., 126., 127., 128.]]]]])
tensor(True)
# ignore