I’m trying to index a tensor `t` to compute norms of subtensors of `t`.
Suppose that `t` has shape `(8, 3, 6, 6)` and that I have an indexing tensor `idx` of shape `(12, 3)`. Each row of `idx` holds the coordinates, in the last three dimensions of `t`, of one of the 12 elements I want to extract from `t` for each element in the batch. I then compute the norm of that 12-element subtensor for each batch element.
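To make the intended computation concrete, here is a minimal explicit-loop sketch. It assumes random data with the shapes above and that each row of `idx` is an `(i, j, k)` coordinate into the last three dimensions of `t`; `reference_norms` is a hypothetical helper meant only as a slow correctness check, not as a fast solution.

```python
import torch

# Hypothetical data with the shapes from the question
torch.manual_seed(0)
t = torch.randn(8, 3, 6, 6)
# Each row of idx is an (i, j, k) coordinate into the last three dims of t
idx = torch.stack([
    torch.randint(0, 3, (12,)),
    torch.randint(0, 6, (12,)),
    torch.randint(0, 6, (12,)),
], dim=1)

def reference_norms(t, idx):
    """Slow but explicit: gather the 12 elements per batch entry, then take the 2-norm."""
    out = torch.empty(t.shape[0], dtype=t.dtype)
    for b in range(t.shape[0]):
        vals = torch.stack([t[b, i, j, k] for i, j, k in idx.tolist()])  # shape (12,)
        out[b] = torch.linalg.norm(vals)
    return out

ref = reference_norms(t, idx)
print(ref.shape)  # torch.Size([8])
```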
I’ve tried `t[idx]` and `t[:, idx]`, but neither gives the result I want. I have one solution that works:
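```python
# The Ellipsis keeps the batch dim; the three columns of idx index the last three dims
idx_ = (...,) + tuple(idx.T)
norm = torch.linalg.norm(t[idx_], dim=-1)  # t[idx_] has shape (8, 12), norm has shape (8,)
```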
But when profiling this, forming `idx_` takes longer than computing the norm, and it feels unwieldy and unnatural. Is there a faster and more natural way to perform this indexing?
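Two alternatives that may be worth profiling against the `idx_` approach are sketched below, under the same assumptions (random data of the shapes above; `flat_offsets` is a hypothetical helper, not a PyTorch function). Option A unbinds the columns of `idx` and indexes with `t[:, i, j, k]`; option B converts each `(i, j, k)` coordinate into a flat row-major offset and gathers from `t.flatten(start_dim=1)`. Both should produce the same norms as the original, but whether either is actually faster depends on sizes and device; this is not a measured result.

```python
import torch

torch.manual_seed(0)
t = torch.randn(8, 3, 6, 6)
idx = torch.stack([
    torch.randint(0, 3, (12,)),
    torch.randint(0, 6, (12,)),
    torch.randint(0, 6, (12,)),
], dim=1)

# Original approach from the question, for comparison
norm_orig = torch.linalg.norm(t[(...,) + tuple(idx.T)], dim=-1)

# Option A: split idx into its three coordinate columns and index the last three dims
i, j, k = idx.unbind(dim=1)                          # each of shape (12,)
norm_a = torch.linalg.norm(t[:, i, j, k], dim=-1)    # t[:, i, j, k] has shape (8, 12)

# Option B (hypothetical helper): turn (i, j, k) into flat offsets into the inner block
def flat_offsets(idx, inner_shape):
    d0, d1, d2 = inner_shape
    # Row-major strides of a (d0, d1, d2) block, e.g. (36, 6, 1) for (3, 6, 6)
    strides = torch.tensor([d1 * d2, d2, 1], dtype=idx.dtype)
    return (idx * strides).sum(dim=1)                # shape (12,)

lin = flat_offsets(idx, t.shape[1:])
norm_b = torch.linalg.norm(t.flatten(start_dim=1)[:, lin], dim=-1)  # gather from (8, 108)
```

If `idx` stays fixed across calls, `lin` can be computed once and reused, so each call only performs a single gather followed by the norm.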