for example, in below code:
x = torch.arange(50).view((1,1,5,5)).float()
I will have a tensor like below:
tensor([[[[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.],
[20., 21., 22., 23., 24.]]]])
and from official document:
torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
: Extracts sliding local blocks from a batched input tensor.
so if I want a kernel of [3,3] sliding through my tensor with zero-padding 1, I would expect some output like this:
unfold = nn.Unfold(kernel_size=3, padding=1)
output = unfold(x)
#should gives
# tensor([[[ 0., 0., 0., 0., 0., 1., 0., 5., 6.],...]]]
make sense right? I apply a 3x3 kernel at the padded tensor, flatten it.
but actually, I need to use
unfold = nn.Unfold(kernel_size=5, padding=1)
output = unfold(x)
I need to set the kernel size to 5 instead of 3!
Just an comparision:
unfold = nn.Unfold(kernel_size=3, padding=1)
output = unfold(x)
tensor([[[ 0., 0., 0., 0., 0., 0., 0., 1., 2., 3., 0., 5., 6., 7.,
8., 0., 10., 11., 12., 13., 0., 15., 16., 17., 18.],
[ 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 5., 6., 7., 8.,
9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
[ 0., 0., 0., 0., 0., 1., 2., 3., 4., 0., 6., 7., 8., 9.,
0., 11., 12., 13., 14., 0., 16., 17., 18., 19., 0.],
[ 0., 0., 1., 2., 3., 0., 5., 6., 7., 8., 0., 10., 11., 12.,
13., 0., 15., 16., 17., 18., 0., 20., 21., 22., 23.],
[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13.,
14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.],
[ 1., 2., 3., 4., 0., 6., 7., 8., 9., 0., 11., 12., 13., 14.,
0., 16., 17., 18., 19., 0., 21., 22., 23., 24., 0.],
[ 0., 5., 6., 7., 8., 0., 10., 11., 12., 13., 0., 15., 16., 17.,
18., 0., 20., 21., 22., 23., 0., 0., 0., 0., 0.],
[ 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24., 0., 0., 0., 0., 0.],
[ 6., 7., 8., 9., 0., 11., 12., 13., 14., 0., 16., 17., 18., 19.,
0., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0.]]])
unfold = nn.Unfold(kernel_size=5, padding=1)
output = unfold(x)
tensor([[[ 0., 0., 0., 0., 0., 1., 0., 5., 6.],
[ 0., 0., 0., 0., 1., 2., 5., 6., 7.],
[ 0., 0., 0., 1., 2., 3., 6., 7., 8.],
[ 0., 0., 0., 2., 3., 4., 7., 8., 9.],
[ 0., 0., 0., 3., 4., 0., 8., 9., 0.],
[ 0., 0., 1., 0., 5., 6., 0., 10., 11.],
[ 0., 1., 2., 5., 6., 7., 10., 11., 12.],
[ 1., 2., 3., 6., 7., 8., 11., 12., 13.],
[ 2., 3., 4., 7., 8., 9., 12., 13., 14.],
[ 3., 4., 0., 8., 9., 0., 13., 14., 0.],
[ 0., 5., 6., 0., 10., 11., 0., 15., 16.],
[ 5., 6., 7., 10., 11., 12., 15., 16., 17.],
[ 6., 7., 8., 11., 12., 13., 16., 17., 18.],
[ 7., 8., 9., 12., 13., 14., 17., 18., 19.],
[ 8., 9., 0., 13., 14., 0., 18., 19., 0.],
[ 0., 10., 11., 0., 15., 16., 0., 20., 21.],
[10., 11., 12., 15., 16., 17., 20., 21., 22.],
[11., 12., 13., 16., 17., 18., 21., 22., 23.],
[12., 13., 14., 17., 18., 19., 22., 23., 24.],
[13., 14., 0., 18., 19., 0., 23., 24., 0.],
[ 0., 15., 16., 0., 20., 21., 0., 0., 0.],
[15., 16., 17., 20., 21., 22., 0., 0., 0.],
[16., 17., 18., 21., 22., 23., 0., 0., 0.],
[17., 18., 19., 22., 23., 24., 0., 0., 0.],
[18., 19., 0., 23., 24., 0., 0., 0., 0.]]])