I am looking for a memory-efficient way to construct a KNN tensor to feed into a linear layer.

An example with K = 5:

I have a pointcloud tensor of shape (N, 3) and a KNN-indices tensor of shape (N, 5). If I gather the neighbors of each point and stack them, the tensor becomes (N, 3, K=5). However, each of the K neighbors is just data from the original (N, 3) pointcloud in a different order.
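A minimal sketch of the straightforward construction described above, using advanced indexing. The sizes, the random placeholder indices and `out_features` are assumptions for illustration only:

```python
import torch

# Hypothetical sizes for illustration: N points, K = 5 neighbors, 16 output features.
N, K, out_features = 1000, 5, 16

points = torch.randn(N, 3)              # pointcloud, shape (N, 3)
knn_idx = torch.randint(0, N, (N, K))   # placeholder KNN indices, shape (N, K)

# Advanced indexing gathers the neighbor coordinates: shape (N, K, 3).
# This allocates a new tensor holding K gathered copies of the point data.
neighbors = points[knn_idx]

# Rearrange to (N, 3, K); permute only returns a view.
knn_tensor = neighbors.permute(0, 2, 1)

# Flatten to (N, 3 * K) so a Linear layer sees 3 * K input features.
# reshape has to copy here because the permuted tensor is not contiguous.
linear = torch.nn.Linear(3 * K, out_features)
out = linear(knn_tensor.reshape(N, 3 * K))
```

With this flattening, input feature `c * K + k` of the linear layer is coordinate `c` of neighbor `k`.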

Is there a memory-efficient way to construct this (N, 3, K) tensor so that it is compatible with linear layers, e.g. using something like K views of the original pointcloud?
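PyTorch views can only describe regular strides, so an arbitrary index pattern like the KNN indices cannot be expressed as a view, and some copying seems unavoidable. One possible direction (a minimal sketch under stated assumptions, not a verified solution) is to use the linearity of the layer: applying `nn.Linear` to the flattened (N, 3 * K) tensor equals summing, over the K neighbors, a small matmul of each gathered (N, 3) slice with the matching slice of the weight, so only one (N, 3) slice exists at a time. The helper name `knn_linear` is hypothetical, and the sketch assumes the flattening order of `reshape(N, 3 * K)` on an (N, 3, K) tensor.

```python
import torch


def knn_linear(points: torch.Tensor, knn_idx: torch.Tensor,
               linear: torch.nn.Linear) -> torch.Tensor:
    """Hypothetical helper: equivalent to
    linear(points[knn_idx].permute(0, 2, 1).reshape(N, C * K))
    without building the full (N, C, K) tensor at once.

    Assumes linear.in_features == C * K and that input feature c * K + k
    corresponds to coordinate c of neighbor k.
    """
    N, K = knn_idx.shape
    C = points.shape[1]
    # Weight (out, C * K) viewed as (out, C, K) to match the flattening order.
    W = linear.weight.view(linear.out_features, C, K)
    out = points.new_zeros(N, linear.out_features)
    if linear.bias is not None:
        out = out + linear.bias
    for k in range(K):
        # Only one (N, C) slice of neighbor coordinates is gathered per step.
        gathered = points[knn_idx[:, k]]       # (N, C)
        out = out + gathered @ W[:, :, k].T    # (N, out)
    return out
```

Under autograd, each gathered slice is still saved for the backward pass of its matmul, so the saving mostly applies at inference time (e.g. inside `torch.no_grad()`).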

I am unsure what would be the best way to approach this problem and am looking for some ideas/guidance.