I have a tensor y of size 1024 x 10 x 10, i.e. a batch of 1024 matrices, each 10 x 10.
I also have two index tensors, a and b, each of size 1024 x 15, which I use to access elements of each of the 1024 matrices: for matrix i, a[i] holds the row indices and b[i] the column indices. I guess this is some form of gather operation.
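Thus, the gathered values should form a 1024 x 15 tensor (in the snippet below I then sum those 15 values, so z ends up with one entry per matrix).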
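One way to read this literally as a gather: flatten each 10 x 10 matrix into a row of 100 values and turn each (row, column) pair into a row-major linear index `a * 10 + b`, then call `torch.gather` along dimension 1. This is a minimal sketch, assuming a holds row indices and b column indices (as in the snippet) and that both are int64, since `torch.gather` requires int64 indices. The helper name `gather_pairs` and the random test tensors are hypothetical, made up for illustration.

```python
import torch

def gather_pairs(y: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: out[i, k] = y[i, a[i, k], b[i, k]] via torch.gather."""
    n, rows, cols = y.shape
    flat = y.reshape(n, rows * cols)   # (n, rows*cols): each matrix flattened into one row
    lin = a * cols + b                 # (n, 15): row-major linear index into each matrix
    return torch.gather(flat, 1, lin)  # (n, 15)

# Random data with the shapes from the question, checked against an explicit loop
torch.manual_seed(0)
y = torch.randn(1024, 10, 10)
a = torch.randint(0, 10, (1024, 15))
b = torch.randint(0, 10, (1024, 15))

out = gather_pairs(y, a, b)
ref = torch.stack([y[i][a[i], b[i]] for i in range(y.size(0))])
print(out.shape)
print(torch.equal(out, ref))
```

Both sides pick exactly the same elements, so an exact `torch.equal` comparison is appropriate here.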
Currently I do this with a Python loop over the first (batch) dimension.
Could you suggest a more vectorized version, or is there a better way? Thank you.
Here is the snippet:
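```python
import torch

# For each matrix i, y[i][a[i], b[i]] picks 15 elements, which are then summed
z = [y[i][a[i], b[i]].sum() for i in range(y.size(0))]
z = torch.tensor(z, dtype=torch.float).to(y.device)
```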
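A loop-free sketch using PyTorch advanced indexing: index the batch dimension with a column vector `torch.arange(1024)[:, None]`, which broadcasts against a and b (both 1024 x 15), so `y[batch, a, b]` returns the 1024 x 15 gathered values in one call. This assumes a and b are int64 tensors on the same device as y; the helper name `batched_pick` and the random data are hypothetical.

```python
import torch

def batched_pick(y: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Batch index of shape (n, 1) broadcasts against a and b of shape (n, 15)
    batch = torch.arange(y.size(0), device=y.device).unsqueeze(1)
    return y[batch, a, b]  # (n, 15): element [i, k] is y[i, a[i, k], b[i, k]]

torch.manual_seed(0)
y = torch.randn(1024, 10, 10)
a = torch.randint(0, 10, (1024, 15))
b = torch.randint(0, 10, (1024, 15))

picked = batched_pick(y, a, b)  # (1024, 15) gathered values
z = picked.sum(dim=1)           # (1024,), like the per-matrix .sum() in the loop

# Loop version for comparison (stacked rather than rebuilt with torch.tensor)
z_loop = torch.stack([y[i][a[i], b[i]].sum() for i in range(y.size(0))])
print(picked.shape, z.shape)
print(torch.allclose(z, z_loop))
```

If the 1024 x 15 values are what you need, stop at `picked`; the `.sum(dim=1)` only reproduces what the loop computes. The comparison uses `allclose` because the two sums may accumulate in a different order. Because no new tensor is built with `torch.tensor`, the result stays on `y.device` and remains part of the autograd graph.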