Say I have a float tensor A of shape [n, n] and two int tensors X, Y of shape [N].
Each value in X and Y is an integer in [0, n-1].
Now I want to create a tensor B of shape [N, N] satisfying B[i][j] = A[X[i]][Y[j]].
Help me, please!
If I understand your use case correctly, this should work:
```python
import torch

N = 5
A = torch.randn(N, N)
x, y = torch.randint(0, N, (N,)), torch.randint(0, N, (N,))
B = torch.zeros(N, N)

# slow approach for verification
for i in range(N):
    for j in range(N):
        B[i, j] = A[x[i], y[j]]

# avoiding loops: x.unsqueeze(1) has shape (N, 1) and broadcasts against y (N,)
C = A[x.unsqueeze(1), y]

# check for equal values
print((B == C).all())
# tensor(True)
```
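The question allows A to be n x n with index tensors of length N, while the example above uses n = N. A minimal sketch with n != N (the sizes are arbitrary assumptions) suggests the same broadcast indexing still applies, shown next to an equivalent two-step `index_select` formulation:

```python
import torch

# Arbitrary sizes for illustration: A is n x n, the index tensors have length N (n != N on purpose)
n, N = 7, 4
A = torch.randn(n, n)
X = torch.randint(0, n, (N,))
Y = torch.randint(0, n, (N,))

# Broadcast indexing: X[:, None] has shape (N, 1) and Y has shape (N,),
# so together they address an (N, N) grid with B[i, j] = A[X[i], Y[j]]
B = A[X[:, None], Y]

# Equivalent two-step selection: pick rows X first (N x n), then columns Y (N x N)
C = A.index_select(0, X).index_select(1, Y)

print(B.shape)            # torch.Size([4, 4])
print(torch.equal(B, C))  # True
```

The `index_select` version makes the row-then-column selection explicit; the broadcast version does it in a single indexing call.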
Thanks, really helpful!
Can I also avoid the outer loop over the batch dimension?
Say A.size() is (Batch, N, N) and x.size() == y.size() == (Batch, N), and I currently loop over the batch:
```python
for b in range(Batch):
    B[b] = A[b, x[b].unsqueeze(1), y[b]]
```
Yes, this should work:
```python
import torch

batch_size, N = 4, 5
A = torch.randn(batch_size, N, N)
x, y = torch.randint(0, N, (batch_size, N)), torch.randint(0, N, (batch_size, N))
B = torch.zeros(batch_size, N, N)

# loop approach
for b in range(batch_size):
    B[b] = A[b, x[b].unsqueeze(1), y[b]]

# avoiding loops: the three index tensors have shapes (batch_size, 1, 1),
# (batch_size, N, 1) and (batch_size, 1, N), which broadcast to (batch_size, N, N)
C = A[torch.arange(A.size(0)).unsqueeze(1).unsqueeze(2), x.unsqueeze(2), y.unsqueeze(1)]

# check for equal values
print((B == C).all())
# tensor(True)
```
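If you would rather avoid advanced indexing, the same batched lookup can be expressed as two `torch.gather` calls. This is a swapped-in alternative, not what the answer above uses; a minimal sketch with arbitrary sizes (n != N chosen on purpose):

```python
import torch

# Arbitrary sizes for illustration: batch of n x n matrices, N indices per batch element
batch_size, n, N = 3, 6, 4
A = torch.randn(batch_size, n, n)
x = torch.randint(0, n, (batch_size, N))
y = torch.randint(0, n, (batch_size, N))

# Step 1: gather rows along dim 1.
# rows[b, i, k] = A[b, x[b, i], k], shape (batch_size, N, n)
rows = A.gather(1, x.unsqueeze(2).expand(-1, -1, n))

# Step 2: gather columns along dim 2.
# C[b, i, j] = rows[b, i, y[b, j]] = A[b, x[b, i], y[b, j]], shape (batch_size, N, N)
C = rows.gather(2, y.unsqueeze(1).expand(-1, N, -1))

# Reference: the advanced-indexing version, written with [:, None, None] instead of unsqueeze
ref = A[torch.arange(batch_size)[:, None, None], x.unsqueeze(2), y.unsqueeze(1)]
print(torch.equal(C, ref))  # True
```

`gather` requires the index tensor to have the same number of dimensions as the input, which is why each index is unsqueezed and expanded first; the advanced-indexing form is shorter and handles that broadcasting for you.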
Thanks very much! It works.