How to do sort subsampling

Can anyone help me with this question?
Let's say I have a 4x4 tensor and I want to do subsampling in the following way: for each 2x2 block, I put the smallest elements together,
then the second smallest elements, and so on, so the output will be 4 tensors, each half the size of the input. The stride will be 2.

Here is an example:

Input: 
A = 
[1     2      3     4
 5     6      7     8
 9     10    11    12
13     14    15    16 ]

Results:
A1 = 
[1     3
 9    11]

A2 = 
[2     4
 10    12]

A3 = 
[5     7
 13   15]

A4 = 
[6      8
 14    16]

And that was just a toy example; the real question is how to do it for a tensor with dimensions `BxCxMxM`.

This code should work:

import torch

kh, kw = 2, 2  # kernel size
dh, dw = 2, 2  # stride
input = torch.randint(10, (1, 2, 4, 4))
input_windows = input.unfold(2, kh, dh).unfold(3, kw, dw)
input_windows = input_windows.contiguous().view(*input_windows.size()[:-2], -1)
input_windows_sorted = input_windows.sort(descending=True)[0]
input_windows_sorted = input_windows_sorted.view(*input.size())
input_windows_sorted = input_windows_sorted.transpose(2, 3)
input_windows_sorted = input_windows_sorted.unfold(3, kh, kw)

Let me know if that works for you.


It works great!!
I was dealing with it for 3 days, was so hopeless, and was just writing loops inside loops lol, and now I'm just staring at your stunning code wondering how it is possible …
Basically I'm googling every line of your code to understand what is going on :))
Thanks a lot


I'm glad it works for you.
I would suggest setting the number of input channels to 1 for easier debugging / understanding.

Maybe there is another way, as I'm not really happy about using unfold twice, so let me know if this code is not fast enough for your use case. There might be some tweaks I'm not thinking of.
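To see what the two unfold calls actually produce, here is a minimal single-channel sketch (my own illustration, using the 4x4 toy example from the top of the thread):

```python
import torch

# single-channel walk-through of the two unfold calls
x = torch.arange(1, 17).view(1, 1, 4, 4)   # the 4x4 toy example from above
w = x.unfold(2, 2, 2)                      # (1, 1, 2, 4, 2): rows grouped pairwise
w = w.unfold(3, 2, 2)                      # (1, 1, 2, 2, 2, 2): 2x2 blocks
print(w[0, 0, 0, 0])                       # top-left block: [[1, 2], [5, 6]]
```

Dims 2 and 3 now index the block position, and the last two dims index the elements within each block.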

Sure! Thank you!

First I was trying to dig into the max pooling function to see how it is written and change it, but it is compiled and I couldn't figure it out.
But this should be fine for now; I want to try it on a simple case first.

Just a quick question: when you use .view, what is the role of * in .view(*input_windows.size()[:-2], -1)?
I haven't seen * in view before.

The * is used to unpack the following tuple or list.
In my code I'm using it to unpack input_windows.size()[:-2] into the separate sizes.
For a tensor of shape (1, 1, 2, 2), Python therefore creates something like this:

tensor.view(*tensor.size(), -1)
# will be unpacked to
tensor.view(1, 1, 2, 2, -1)
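As a quick runnable sketch of the unpacking (assuming a tensor of shape (1, 1, 2, 2)):

```python
import torch

t = torch.zeros(1, 1, 2, 2)
print(t.size())                            # torch.Size([1, 1, 2, 2])
# * unpacks the size tuple into separate positional arguments,
# so this call is equivalent to t.view(1, 1, -1)
print(t.view(*t.size()[:-2], -1).size())   # torch.Size([1, 1, 4])
```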

You can read more about this operation here.

I see, got it!
Thanks

Sorry, I realized that the code is not working exactly as expected. Can you please let me know your opinion?
Here is an example:

kh, kw = 2, 2
dh, dw = 2, 2
input = torch.rand(1,2,4,4)
input_windows = input.unfold(2, kh, dh).unfold(3, kw, dw)
input_windows = input_windows.contiguous().view(*input_windows.size()[:-2], -1)
input_windows_sorted = input_windows.sort(descending=True)[0]
input_windows_sorted = input_windows_sorted.view(*input.size())
input_windows_sorted = input_windows_sorted.transpose(2, 3)
input_windows_sorted = input_windows_sorted.unfold(3, kh, kw)
print(input_windows_sorted.size())

Output: torch.Size([1, 2, 4, 2, 2])

but when I change the spatial size I see:

input = torch.rand(1,2,6,6)
input_windows = input.unfold(2, kh, dh).unfold(3, kw, dw)
input_windows = input_windows.contiguous().view(*input_windows.size()[:-2], -1)
input_windows_sorted = input_windows.sort(descending=True)[0]
input_windows_sorted = input_windows_sorted.view(*input.size())
input_windows_sorted = input_windows_sorted.transpose(2, 3)
input_windows_sorted = input_windows_sorted.unfold(3, kh, kw)
print(input_windows_sorted.size())

Output: torch.Size([1, 2, 6, 3, 2])

but the output should be

Output: torch.Size([1, 2, 4, 3, 3])

You are right! Thanks for pointing this out.
Here is a (hopefully) fixed version:

kh, kw = 2, 2
dh, dw = 2, 2
input = torch.randint(10, (1, 2, 6, 6))
input_windows = input.unfold(2, kh, dh).unfold(3, kw, dw)
input_windows = input_windows.contiguous().view(*input_windows.size()[:-2], -1)
input_windows_sorted = input_windows.sort(descending=True)[0]
# move the sorted-rank dim in front of the spatial dims -> (B, C, kh*kw, H//dh, W//dw)
input_windows_sorted = input_windows_sorted.permute(0, 1, 4, 2, 3)
print(input_windows_sorted.size())
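As a sanity check, running the same steps on the 4x4 toy example from the top of the thread should reproduce A1, A2, etc. (Note: I use descending=False here so that the smallest elements come first, matching the A1..A4 ordering; the snippet above uses descending=True, which gives the largest first.)

```python
import torch

# sanity check on the 4x4 toy example
A = torch.arange(1, 17).view(1, 1, 4, 4)
kh, kw, dh, dw = 2, 2, 2, 2
w = A.unfold(2, kh, dh).unfold(3, kw, dw)
w = w.contiguous().view(*w.size()[:-2], -1)
s = w.sort(descending=False)[0]     # ascending: smallest elements first
out = s.permute(0, 1, 4, 2, 3)      # (1, 1, 4, 2, 2)
print(out[0, 0, 0])                 # A1: [[1, 3], [9, 11]]
print(out[0, 0, 1])                 # A2: [[2, 4], [10, 12]]
```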

Yes, it works I think!
Just one last question:
if I want to concatenate the results based on channels, is there a faster way than:

B = input_windows_sorted[:, 0, :, :, :]
for i in range(1, input_windows_sorted.size(1)):
    B = torch.cat((B, input_windows_sorted[:, i, :, :, :]), 1)

For this example, B will have the size [1, 8, 3, 3].

You could use a view for it:

input_windows_sorted = input_windows_sorted.contiguous().view(
    input_windows_sorted.size(0), -1, *input_windows_sorted.size()[-2:])

If the batch size is known beforehand, you can use it instead of input_windows_sorted.size(0), which makes the code a bit more readable.
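A quick check that the view gives the same result as the loop, using a random tensor of the shape from the example:

```python
import torch

x = torch.randn(1, 2, 4, 3, 3)   # (B, C, kh*kw, H', W') as in the example
# loop-based concatenation along the channel dim
B = x[:, 0]
for i in range(1, x.size(1)):
    B = torch.cat((B, x[:, i]), 1)
# single view, flattening C and kh*kw into one dim
B_view = x.contiguous().view(x.size(0), -1, *x.size()[-2:])
print(torch.equal(B, B_view))    # True; both have shape (1, 8, 3, 3)
```

The two agree because the view flattens dims 1 and 2 with dim 1 as the outer index, which is exactly the order the loop concatenates in.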
