How can I implement this with fast PyTorch functions?

Say I have a float tensor A of shape [n, n] and two int tensors X, Y of shape [N].
Each value in X and Y is an int in [0, n-1].
Now I want to create a tensor B of shape [N, N] satisfying B[i][j] = A[X[i]][Y[j]].

Please help.

If I understand your use case correctly, this should work:

import torch

N = 5  # uses n == N for simplicity
A = torch.randn(N, N)
x, y = torch.randint(0, N, (N,)), torch.randint(0, N, (N,))

B = torch.zeros(N, N)

# slow approach for verification
for i in range(N):
    for j in range(N):
        B[i, j] = A[x[i], y[j]]

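# avoiding loops: x.unsqueeze(1) has shape [N, 1] and y has shape [N],
# so the indices broadcast to [N, N] and C[i, j] = A[x[i], y[j]]
C = A[x.unsqueeze(1), y]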

# check for equal values
print((B == C).all())
> tensor(True)
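As a sketch (not part of the original answer), the same indexing also covers the case in the question where the size of A (n) differs from the length of X and Y (N). The snippet below uses arbitrary sizes n = 4 and N = 6 and compares the broadcast indexing with an equivalent two-step `index_select` (rows first, then columns) and with a plain loop.

```python
import torch

n, N = 4, 6  # A is n x n; X and Y have length N (n != N on purpose)
A = torch.randn(n, n)
X = torch.randint(0, n, (N,))
Y = torch.randint(0, n, (N,))

# Broadcast advanced indexing: X.unsqueeze(1) has shape [N, 1], Y has shape [N],
# so the two index tensors broadcast to [N, N] and B[i, j] = A[X[i], Y[j]].
B_broadcast = A[X.unsqueeze(1), Y]

# Equivalent two-step selection: pick the rows X first, then the columns Y.
B_select = A.index_select(0, X).index_select(1, Y)

# Reference loop, only used to check the results.
B_loop = torch.empty(N, N)
for i in range(N):
    for j in range(N):
        B_loop[i, j] = A[X[i], Y[j]]

print(B_broadcast.shape)  # torch.Size([6, 6])
print(torch.equal(B_broadcast, B_select), torch.equal(B_broadcast, B_loop))  # True True
```

Both vectorized forms only copy values out of A, so they match the loop exactly rather than just approximately.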

Thanks, really helpful!

Can I also avoid the outer loop over the batch dimension?
Say

A.size() == (Batch, N, N)
x.size() == y.size() == (Batch, N)

for b in range(Batch):
    B[b] = A[b, x[b].unsqueeze(1), y[b]]

Yes, this should work:

import torch

batch_size, N = 4, 5
A = torch.randn(batch_size, N, N)
x, y = torch.randint(0, N, (batch_size, N)), torch.randint(0, N, (batch_size, N))

B = torch.zeros(batch_size, N, N)

# loop approach
for b in range(batch_size):
    B[b] = A[b, x[b].unsqueeze(1), y[b]]


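# avoiding loops: the three index tensors have shapes [batch_size, 1, 1],
# [batch_size, N, 1] and [batch_size, 1, N], which broadcast to [batch_size, N, N]
C = A[torch.arange(A.size(0)).unsqueeze(1).unsqueeze(2), x.unsqueeze(2), y.unsqueeze(1)]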

# check for equal values
print((B == C).all())
> tensor(True)
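For reference, here is a sketch of two other loop-free formulations of the batched case, checked against the advanced-indexing answer above: two chained `torch.gather` calls, and a single gather on flattened matrices using linear indices `x * n + y`. The sizes are arbitrary illustration values (with n != N), and which variant is fastest will likely depend on your sizes and device, so it is worth benchmarking before picking one.

```python
import torch

batch_size, n, N = 3, 4, 6  # A is [B, n, n]; x and y are [B, N]
A = torch.randn(batch_size, n, n)
x = torch.randint(0, n, (batch_size, N))
y = torch.randint(0, n, (batch_size, N))

# Advanced indexing from the answer: index shapes [B,1,1], [B,N,1], [B,1,N] -> [B,N,N]
C_index = A[torch.arange(batch_size).view(-1, 1, 1), x.unsqueeze(2), y.unsqueeze(1)]

# Alternative 1: gather the rows, then gather the columns.
rows = torch.gather(A, 1, x.unsqueeze(2).expand(-1, -1, n))         # [B, N, n]
C_gather = torch.gather(rows, 2, y.unsqueeze(1).expand(-1, N, -1))   # [B, N, N]

# Alternative 2: flatten each n x n matrix and use linear indices row * n + col.
lin = x.unsqueeze(2) * n + y.unsqueeze(1)                            # [B, N, N]
C_flat = (
    A.reshape(batch_size, n * n)
    .gather(1, lin.reshape(batch_size, -1))
    .reshape(batch_size, N, N)
)

# Reference loop, only used to check the results.
C_loop = torch.empty(batch_size, N, N)
for b in range(batch_size):
    C_loop[b] = A[b, x[b].unsqueeze(1), y[b]]

print(torch.equal(C_index, C_gather), torch.equal(C_index, C_flat), torch.equal(C_index, C_loop))  # True True True
```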

Thanks very much! It works.