import torch
# Sizes of the three leading dimensions of the demo tensor.
h, w, c = 4, 4, 2
# Random demo input of shape (h, w, c, 4).
A = torch.randn((h, w, c, 4))
def func_1(data):
    """Flatten a 4-D tensor to 2-D by explicit row-major iteration.

    Loop-based reference implementation: each ``data[i, j, k, :]`` slice is
    copied into the next row of a freshly allocated tensor, with the third
    index ``k`` varying fastest, then ``j``, then ``i``.

    Args:
        data: Tensor of shape ``(n_i, n_j, n_k, n_d)``.

    Returns:
        A new tensor of shape ``(n_i * n_j * n_k, n_d)`` with the same dtype
        and device as ``data``.

    Raises:
        ValueError: If ``data`` is not 4-dimensional.
    """
    if data.dim() != 4:
        raise ValueError(f"expected a 4-D tensor, got {data.dim()}-D")

    # Sizes come from the argument itself, not from module-level globals,
    # so the function works for whatever tensor the caller passes in.
    n_i, n_j, n_k, n_d = data.shape

    # The buffer has exactly one row per (i, j, k) triple, so the row
    # counter can never run past the end and no bounds assert is needed.
    out = data.new_zeros((n_i * n_j * n_k, n_d))

    row = 0
    for i in range(n_i):
        for j in range(n_j):
            for k in range(n_k):
                out[row, :] = data[i, j, k, :]
                row += 1
    return out
def func_2(data):
    """Flatten ``data`` into a 2-D tensor with 4 columns.

    Args:
        data: Tensor whose total number of elements is a multiple of 4.

    Returns:
        A tensor of shape ``(data.numel() // 4, 4)`` holding the elements of
        ``data`` in row-major order. It shares storage with ``data`` when
        the layout allows a view; otherwise it is a copy.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed or permuted tensors, and still returns a view when it can.
    return data.reshape(-1, 4)
# Show the shape of the view-based flattening.
print(func_2(A).size())
# Element-wise difference between the loop-based and view-based results;
# every entry is zero when the two flattenings agree.
print(func_1(A).sub(func_2(A)))