Let's say I have a list consisting of K tensors:
L = [ torch.rand(B,C,D,D) for _ in range(K)]
For simplicity, let's say B = 1.
I want to find the maximum value at each element position across the K tensors, and the index (0 to K-1) of the tensor that holds that maximum.
How can I do that efficiently?
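One possible approach, assuming all K tensors share the same shape (B, C, D, D), is to stack them along a new leading dimension with torch.stack and reduce over that dimension with torch.max(..., dim=0). This handles every b and c in a single call. A minimal sketch (the sizes are made up for illustration):

```python
import torch

# Hypothetical sizes for illustration; the question only fixes B = 1.
B, C, D, K = 1, 4, 3, 5
L = [torch.rand(B, C, D, D) for _ in range(K)]

# Stack the K tensors along a new leading dimension -> shape (K, B, C, D, D).
stacked = torch.stack(L, dim=0)

# Reduce over that new dimension: `values` holds the element-wise maximum,
# `index` holds which of the K tensors it came from (an integer in [0, K)).
values, index = torch.max(stacked, dim=0)

print(values.shape)  # torch.Size([1, 4, 3, 3])
print(index.shape)   # torch.Size([1, 4, 3, 3])
```

The cost is one extra (K, B, C, D, D) copy of the data, created by torch.stack.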
For example, let's say I have:
import torch
B = 1
C = 1
D = 3
K = 2
L = [ torch.rand(B,C,D,D) for _ in range(K)]
print(L)
[tensor([[[[ 0.9226, 0.3428, 0.5824],
[ 0.4465, 0.5420, 0.3884],
[ 0.2781, 0.2483, 0.6952]]]]), tensor([[[[ 0.9592, 0.1106, 0.9677],
[ 0.3140, 0.6602, 0.7806],
[ 0.8007, 0.5433, 0.7550]]]])]
Then, if I want the max and the index of the max value at C = 0, I can do this:
# Take the (b=0, c=0) slice of each tensor and add a leading dimension -> (1, D, D)
A1 = torch.unsqueeze(L[0][0, 0, :, :], 0)
A2 = torch.unsqueeze(L[1][0, 0, :, :], 0)
# Concatenate into a (K, D, D) tensor and reduce over dimension 0
A = torch.cat((A1, A2), 0)
values, index = torch.max(A, 0)
print(values)
print(index)
tensor([[ 0.9592, 0.3428, 0.9677],
[ 0.4465, 0.6602, 0.7806],
[ 0.8007, 0.5433, 0.7550]])
tensor([[ 1, 0, 1],
[ 0, 1, 1],
[ 1, 1, 1]])
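For the C = 0 slice specifically, the unsqueeze + cat steps above are what torch.stack does in a single call, and a list comprehension extends it to any K. A sketch that checks both give identical results on fresh random data (shapes taken from the example):

```python
import torch

B, C, D, K = 1, 1, 3, 2
L = [torch.rand(B, C, D, D) for _ in range(K)]

# The two-tensor version from the question: slice C = 0, add a leading dim, concatenate.
A = torch.cat((torch.unsqueeze(L[0][0, 0, :, :], 0),
               torch.unsqueeze(L[1][0, 0, :, :], 0)), 0)
values_cat, index_cat = torch.max(A, 0)

# torch.stack does the unsqueeze + cat in one call and works for any K.
A_stacked = torch.stack([t[0, 0] for t in L], dim=0)  # shape (K, D, D)
values_stack, index_stack = torch.max(A_stacked, dim=0)

print(torch.equal(values_cat, values_stack))  # True
print(torch.equal(index_cat, index_stack))    # True
```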
But this is not efficient when B and C are large, because it handles one (b, c) slice at a time. Any suggestions?
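If the (K, B, C, D, D) copy made by torch.stack is too large for memory, one alternative is a running maximum: loop over the K tensors (not over B and C) and update the best value and its index element-wise. This is a sketch assuming all tensors have the same shape; `running_max` is a hypothetical helper name, not a PyTorch function.

```python
import torch

def running_max(tensors):
    """Hypothetical helper: element-wise max over a list of same-shaped tensors,
    plus the list index of the max, without building a stacked (K, ...) copy."""
    values = tensors[0].clone()
    index = torch.zeros_like(values, dtype=torch.long)
    for k, t in enumerate(tensors[1:], start=1):
        # Strictly greater, so on ties the earlier tensor keeps its index.
        better = t > values
        values = torch.where(better, t, values)
        index.masked_fill_(better, k)
    return values, index

torch.manual_seed(0)
B, C, D, K = 8, 16, 3, 4  # made-up sizes with larger B and C
L = [torch.rand(B, C, D, D) for _ in range(K)]

values, index = running_max(L)

# Compare against the stack-based reduction.
ref_values, ref_index = torch.max(torch.stack(L, dim=0), dim=0)
print(torch.equal(values, ref_values))  # True
```

The trade-off is K - 1 Python-level iterations instead of one fused call; for small K this is usually cheap, since each iteration is itself vectorized over B, C and D.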