Extracting subtensors from tensors

I’m trying to index a tensor t to compute norms of subtensors included in t.
Suppose that t is of shape (8, 3, 6, 6), and that I have an indexing tensor idx of shape (12, 3). Each row of idx contains the coordinates of one of the 12 elements I want to extract from t for each element in the batch. I then compute the norm of the extracted subtensor.

I’ve tried t[idx] and t[:, idx], neither of which works. I have one solution that works:

idx_ = (...,) + tuple(idx.T)
norm = torch.linalg.norm(t[idx_], dim=-1)

But when profiling this, forming idx_ is slower than computing the norm, and it seems quite unwieldy and unnatural. Is there a faster and more natural way to perform this indexing?

Could you explain how the idx tensor should be used in indexing t?
Since dim1 of idx has size 3, do you want to use the three values in each row as indices into dim1, dim2, and dim3 of t?

Yes, exactly. The 3 values are the indexes in dim1 to dim3.
The resulting tensor should have shape (8, 12), give or take some squeezing.
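
To make the intended semantics concrete, here is a minimal reference sketch using explicit loops. The random data, the seed, and the name `expected` are assumptions for illustration only; it is far too slow for real use, but it spells out which element ends up where in the (8, 12) result.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

t = torch.randn(8, 3, 6, 6)
# Hypothetical index tensor: each row is a (dim1, dim2, dim3) coordinate in t
idx = torch.stack(
    (
        torch.randint(0, 3, (12,)),
        torch.randint(0, 6, (12,)),
        torch.randint(0, 6, (12,)),
    ),
    dim=1,
)

# Reference semantics: expected[b, k] = t[b, idx[k, 0], idx[k, 1], idx[k, 2]]
expected = torch.empty(t.shape[0], idx.shape[0])
for b in range(t.shape[0]):
    for k in range(idx.shape[0]):
        c, i, j = idx[k].tolist()
        expected[b, k] = t[b, c, i, j]

# Norm over the 12 extracted elements, one value per batch element
norm = torch.linalg.norm(expected, dim=-1)
```

The loop result agrees with the tuple-based indexing from the first post, so any vectorized alternative can be checked against it.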

Would another representation for idx make this work?

If I understand the use case correctly, this should work:

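import torch

x = torch.randn(8, 3, 6, 6)
# Random (dim1, dim2, dim3) coordinates: dim1 in [0, 3), dim2 and dim3 in [0, 6)
idx = torch.cat((torch.randint(0, 3, (12, 1)), torch.randint(0, 6, (12, 2))), 1)

# One index tensor per indexed dimension; the result has shape (8, 12)
res = x[:, idx[:, 0], idx[:, 1], idx[:, 2]]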
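
As a quick sanity check, the sketch below compares this direct indexing with the tuple-based approach from the original question and then computes the norm. The variable names `res_tuple` and `res_direct` and the use of `unbind` to split the columns are my own choices, not something from the posts above; whether this is actually faster would need to be profiled on your setup.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

x = torch.randn(8, 3, 6, 6)
idx = torch.cat((torch.randint(0, 3, (12, 1)), torch.randint(0, 6, (12, 2))), 1)

# Approach from the question: build an index tuple from the columns of idx
res_tuple = x[(...,) + tuple(idx.T)]

# Direct approach: split idx into one index tensor per dimension
c, i, j = idx.unbind(dim=1)
res_direct = x[:, c, i, j]

# Both approaches select the same elements, with shape (8, 12)
same = torch.equal(res_tuple, res_direct)

# Norm over the 12 extracted elements, one value per batch element
norm = torch.linalg.norm(res_direct, dim=-1)
```

Both forms hand the same three index tensors to PyTorch's advanced indexing, which is why the results match.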