How to get the max (and the index of the max value) of tensors in a list

Let's say I have a list consisting of K tensors:

L = [ torch.rand(B,C,D,D) for _ in range(K)]

For simplicity, let's say B=1.
For each element position, I want to find the max value across the K tensors and the index of the tensor that holds it.

How can I do that efficiently?

Example:
Let's say I have:

import torch

B = 1
C = 1
D = 3
K = 2

L = [ torch.rand(B,C,D,D) for _ in range(K)]
print(L)
[tensor([[[[ 0.9226,  0.3428,  0.5824],
          [ 0.4465,  0.5420,  0.3884],
          [ 0.2781,  0.2483,  0.6952]]]]), tensor([[[[ 0.9592,  0.1106,  0.9677],
          [ 0.3140,  0.6602,  0.7806],
          [ 0.8007,  0.5433,  0.7550]]]])]

Then, if I want the max and the index of the max value at C=0, I can do this:

# Take the (D, D) slice at B=0, C=0 from each tensor and add a leading dim.
A1 = torch.unsqueeze(L[0][0,0,:,:],0)
A2 = torch.unsqueeze(L[1][0,0,:,:],0)

# Join the slices into shape (K, D, D) and reduce over the list dimension.
A = torch.cat((A1,A2),0)
values, index = torch.max(A,0)
print(values)
print(index)
tensor([[ 0.9592,  0.3428,  0.9677],
        [ 0.4465,  0.6602,  0.7806],
        [ 0.8007,  0.5433,  0.7550]])
tensor([[ 1,  0,  1],
        [ 0,  1,  1],
        [ 1,  1,  1]])

But this is not efficient when B and C are large.

Any suggestions?

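You could use torch.stack instead of torch.cat. It joins the tensors along a new leading dimension of size K, so a single max over dim 0 covers every batch and channel at once: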

L = [ torch.rand(B,C,D,D) for _ in range(K)]
print(L)

stacked = torch.stack(L)          # shape (K, B, C, D, D)
values, index = stacked.max(0)    # both have shape (B, C, D, D)
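To make the shapes concrete, here is a minimal sketch of the stack-then-max approach for larger B and C. The sizes, the seed, and the variable names (stacked, manual, and so on) are illustrative assumptions, not taken from the posts above. It also cross-checks one (B, C) slice against the manual torch.cat approach from the question.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

# Illustrative sizes (assumed for this sketch, not from the original post).
B, C, D, K = 2, 3, 4, 5
L = [torch.rand(B, C, D, D) for _ in range(K)]

# Stack along a new leading dimension: shape (K, B, C, D, D).
stacked = torch.stack(L)

# Reduce over dim 0, the list dimension. `values` holds the element-wise
# max across the K tensors; `index` holds which tensor in L supplied it.
values, index = stacked.max(dim=0)

print(values.shape)  # torch.Size([2, 3, 4, 4])
print(index.shape)   # torch.Size([2, 3, 4, 4])

# Cross-check the slice at B=0, C=0 against the manual cat-based approach.
manual = torch.cat([t[0, 0].unsqueeze(0) for t in L], 0)  # shape (K, D, D)
manual_values, manual_index = manual.max(0)
assert torch.equal(values[0, 0], manual_values)
assert torch.equal(index[0, 0], manual_index)
```

Note that torch.stack copies the K tensors into one new tensor, so it needs roughly as much extra memory as the list itself, and all tensors in the list must have the same shape.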