Extract and Reshape

Is there an efficient way to do the operation above without using a for loop?

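One way, with a loop:

import torch

h = 4
w = 4
c = 2
given_tensor = torch.randn(h * w * c * 4)  # placeholder; the actual input tensor is not shown
A = given_tensor.reshape(h, w, c, 4)

B = torch.zeros(h * w * c, 4)  # one row per (i, j, k): 4 * 4 * 2 = 32 rows

# Copy each length-4 vector A[i, j, k, :] into the next row of B
count = 0
for i in range(h):
    for j in range(w):
        for k in range(c):
            B[count, :] = A[i, j, k, :]
            count = count + 1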

I am not sure I understand this operation correctly.
Could you show the operation with a for loop?

Based on the figure, I think you could look at unfold (torch.nn.functional.unfold, or Tensor.unfold for a single dimension) with a kernel size and stride of 4.
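For the layout in the question, unfolding with a window size and step of 4 amounts to cutting the flat tensor into non-overlapping rows of 4. A minimal sketch using the 1-D method Tensor.unfold(dimension, size, step) rather than the 4-D torch.nn.functional.unfold, assuming given_tensor is a flat tensor of h * w * c * 4 elements (the random input is only a placeholder):

```python
import torch

h, w, c = 4, 4, 2
given_tensor = torch.randn(h * w * c * 4)  # placeholder for the tensor in the question

# Windows of length 4 taken every 4 elements along dim 0: no overlap, no gaps
rows = given_tensor.unfold(0, 4, 4)
print(rows.shape)  # torch.Size([32, 4])

# With size == step, unfold produces the same rows as a plain view/reshape
print(torch.equal(rows, given_tensor.view(-1, 4)))  # True
```

Because size equals step, every element lands in exactly one window, which is why the result matches view(-1, 4).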

I edited the post to show the loop; hopefully it makes sense.

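Based on your code, view is enough. Since A here is contiguous and the loop visits (i, j, k) in row-major order (k fastest, then j, then i), row (i * w + j) * c + k of B is exactly A[i, j, k, :], so A.view(-1, 4) gives the same rows without a loop or a copy.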

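import torch

h = 4
w = 4
c = 2
A = torch.randn(h, w, c, 4)


def func_1(data):
    # Loop version: row `count` of B receives data[i, j, k, :]
    B = torch.zeros(h * w * c, 4)
    count = 0
    for i in range(h):
        for j in range(w):
            for k in range(c):
                B[count, :] = data[i, j, k, :]
                count = count + 1
    return B


def func_2(data):
    # Loop-free version: merge the first three dims in row-major order (no copy)
    return data.view(-1, 4)


print(func_2(A).shape)                    # torch.Size([32, 4])
print(torch.equal(func_1(A), func_2(A)))  # True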
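view(-1, 4) works here because A is freshly created and therefore contiguous. If A were instead a non-contiguous result (for example after a permute), view can raise a RuntimeError, while reshape falls back to copying and still returns the rows in loop order. A small sketch; the permutation is just an illustrative example, and the list comprehension is a loop-based reference equivalent to func_1:

```python
import torch

h, w, c = 4, 4, 2
A = torch.randn(h, w, c, 4)

# Swap the first two dims: same data, non-contiguous strides
A_t = A.permute(1, 0, 2, 3)
print(A_t.is_contiguous())  # False

try:
    A_t.view(-1, 4)
except RuntimeError as err:
    print("view failed:", type(err).__name__)  # view failed: RuntimeError

# reshape copies when needed and keeps row-major (loop) order
B = A_t.reshape(-1, 4)

# Loop-based reference: rows ordered by (i, j, k) with k varying fastest
B_loop = torch.stack([A_t[i, j, k, :]
                      for i in range(A_t.shape[0])
                      for j in range(A_t.shape[1])
                      for k in range(A_t.shape[2])])
print(torch.equal(B, B_loop))  # True
```

If you cannot guarantee contiguity, reshape is the safer default; it returns a view whenever one is possible.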