How to get the max (and the index of the max value) of tensors in a list

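You could use torch.stack instead of torch.cat. Stacking adds a new leading dimension of size K, so taking the max over dim 0 returns both the element-wise maximum and the index of the list entry it came from: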

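import torch

B, C, D, K = 2, 3, 4, 5  # example sizes, pick your own
L = [torch.rand(B, C, D, D) for _ in range(K)]
print(L)

L = torch.stack(L)  # shape: (K, B, C, D, D)
values, indices = L.max(0)  # element-wise max over the K tensors, plus the list index of each max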
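
To check what the returned indices mean, here is a minimal sketch. The sizes and the names `tensors`, `stacked`, `recovered`, and `k` are only illustrative assumptions, not part of the original question. It gathers from the stacked tensor with the indices to confirm they point at the list entry holding each maximum, and also finds which tensor in the list contains the single global maximum.

```python
import torch

# Hypothetical small sizes, chosen only for illustration.
K, B, C, D = 3, 2, 2, 4
tensors = [torch.rand(B, C, D, D) for _ in range(K)]

stacked = torch.stack(tensors)         # shape (K, B, C, D, D)
values, indices = stacked.max(dim=0)   # both have shape (B, C, D, D)

# indices[b, c, i, j] is the position in `tensors` of the tensor holding the max
# at that location. Gathering along dim 0 with those indices recovers the values.
recovered = stacked.gather(0, indices.unsqueeze(0)).squeeze(0)

# Global maximum: argmax() with no dim returns an index into the flattened tensor.
# Dividing by the number of elements per list entry gives the list position.
flat_idx = int(stacked.argmax())
k = flat_idx // tensors[0].numel()
```

If you only need the values and not the indices, `stacked.amax(dim=0)` should give the same element-wise maximum.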