Suppose we have a tensor **T** of size `(b, h*w, d)` and a tensor **Index** of size `(b, h*w, k)` whose values lie in the range `[0, h*w-1]`. For each of the `b * h * w` positions, **Index** stores `k` different indices into the second dimension of **T**.

How can I generate the indexed tensor `res` of size `(b, h*w, k, d)` from **T**, i.e. `res[i, j, l, :] = T[i, Index[i, j, l], :]`, **without extra memory cost**?
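To pin down the intended semantics, here is a naive triple loop that builds the target tensor element by element. It is only a slow reference for checking faster versions, not a proposed solution; the function name `gather_reference` is hypothetical, and it assumes `Index` holds `int64` values.

```python
import torch

def gather_reference(T: torch.Tensor, Index: torch.Tensor) -> torch.Tensor:
    """Naive reference (hypothetical helper): res[i, j, l, :] = T[i, Index[i, j, l], :]."""
    b, n, d = T.shape          # n = h*w
    k = Index.shape[2]
    res = T.new_empty(b, n, k, d)
    for i in range(b):
        for j in range(n):
            for l in range(k):
                # Copy the d-dimensional row of T selected by this index.
                res[i, j, l] = T[i, int(Index[i, j, l])]
    return res
```

For example, with `b=2`, `h*w=4`, `k=2`, `d=3`, the entry `res[1, 0, 0]` is simply the row `T[1, Index[1, 0, 0]]`.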

A memory-costly way to implement this is:

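```
import torch

# T: (b, h*w, d), Index: (b, h*w, k) with int64 values in [0, h*w-1]
T_exp = T.unsqueeze(1).expand(b, h*w, h*w, d)        # (b, h*w, h*w, d), a view of T
Index_exp = Index.unsqueeze(3).expand(b, h*w, k, d)  # (b, h*w, k, d), a view of Index
res = torch.gather(T_exp, 2, Index_exp)              # (b, h*w, k, d)
```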
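A sketch of one lower-memory alternative (not taken from the question, and not claimed to be optimal): gather along dimension 1 of the unexpanded `T` after flattening the `k` indices of every position. Only the index is expanded, as a stride-0 view, so no `(b, h*w, h*w, d)` tensor is ever formed; the output itself still needs `b*h*w*k*d` elements. The name `gather_neighbors` is hypothetical.

```python
import torch

def gather_neighbors(T: torch.Tensor, Index: torch.Tensor) -> torch.Tensor:
    """Return res with res[i, j, l, :] == T[i, Index[i, j, l], :] (hypothetical helper).

    T: (b, n, d) with n = h*w; Index: (b, n, k) with int64 values in [0, n-1].
    """
    b, n, d = T.shape
    k = Index.shape[2]
    # Flatten the k indices of every position, then broadcast them over d (a view, no copy).
    idx = Index.reshape(b, n * k, 1).expand(b, n * k, d)
    # out[i, m, c] = T[i, idx[i, m, c], c]; reads T directly instead of an expanded copy.
    out = torch.gather(T, 1, idx)
    # gather returns a new contiguous tensor, so view() can restore the (n, k) split.
    return out.view(b, n, k, d)
```

The same result should also be obtainable with advanced indexing, `T[torch.arange(b)[:, None, None], Index]`, which broadcasts the batch index against `Index` to shape `(b, h*w, k)` and appends the trailing `d` dimension.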