Long to wide tensor according to indices

I want to create a wide array from long data.

I have edge indices, edge_index = tensor([[0, 0, 0, 1, 1, 1, ...], [0, 2, 3, 1, 2, 3, ...]]), where node 0 is connected to nodes 0, 2 and 3, and each edge has three features, edge_attr = tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], ...]).

So edge_index.shape = [2, n_edges] and edge_attr.shape = [n_edges, 3] (I'm using PyTorch Geometric's Data object).
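For concreteness, here is a small sketch of that input layout, plus how n_nodes and max_n_edges for the target shape can be read off edge_index. Assumptions: nodes are numbered 0..n_nodes-1, edge_index[0] is the source node of each edge, and only the first two rows of edge_attr come from the question (the rest just continue the pattern and are made up).

```python
import torch

# Toy data in the layout described above; rows after the first two are illustrative only.
edge_index = torch.tensor([[0, 0, 0, 1, 1, 1],
                           [0, 2, 3, 1, 2, 3]])
edge_attr = torch.arange(1, 19, dtype=torch.float32).view(6, 3) / 10

n_edges = edge_index.size(1)
n_nodes = int(edge_index.max()) + 1          # assumes nodes are numbered 0..n_nodes-1
# Count how many edges leave each source node.
edges_per_node = torch.bincount(edge_index[0], minlength=n_nodes)
max_n_edges = int(edges_per_node.max())

print(edge_index.shape, edge_attr.shape)     # torch.Size([2, 6]) torch.Size([6, 3])
print(edges_per_node.tolist(), max_n_edges)  # [3, 3, 0, 0] 3
```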

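I am trying to reshape the edge attributes into a tensor of shape [n_nodes, max_n_edges, 3], where the second dimension indexes a node's connections and the third dimension holds the 3 edge features. For example, for node 0 the first feature of its 3 connections would form one row, with the remaining features stacked along the third dimension; nodes with fewer than max_n_edges connections would be padded.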
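One way to do this with plain PyTorch tensor ops is sketched below: stable-sort the edges by source node, compute each edge's slot within its node's group, and assign with advanced indexing into a pre-filled tensor. This is a minimal sketch, not a PyTorch Geometric API. The function name edges_to_dense and its fill_value/mask outputs are hypothetical, and it assumes edge_index[0] holds the source node of each edge and that there is at least one edge.

```python
import torch

def edges_to_dense(edge_index, edge_attr, num_nodes=None, fill_value=0.0):
    """Hypothetical helper: scatter [n_edges, F] edge features into [num_nodes, max_n_edges, F].

    Assumes edge_index[0] holds source node ids in 0..num_nodes-1 and at least one edge.
    Returns the padded tensor and a boolean mask marking real (non-padded) slots.
    """
    src = edge_index[0]
    if num_nodes is None:
        num_nodes = int(src.max()) + 1
    # Stable sort groups each node's edges together and keeps their original order.
    src_sorted, order = torch.sort(src, stable=True)
    attr_sorted = edge_attr[order]
    # Edges per node, and the offset where each node's group starts in the sorted list.
    counts = torch.bincount(src_sorted, minlength=num_nodes)
    starts = torch.cumsum(counts, dim=0) - counts
    # Slot index of each edge within its node's group: 0, 1, 2, ... per node.
    slot = torch.arange(src_sorted.numel(), device=src.device) - starts[src_sorted]
    max_n_edges = int(counts.max())
    out = edge_attr.new_full((num_nodes, max_n_edges, edge_attr.size(1)), fill_value)
    out[src_sorted, slot] = attr_sorted
    mask = torch.zeros(num_nodes, max_n_edges, dtype=torch.bool, device=src.device)
    mask[src_sorted, slot] = True
    return out, mask

# Toy input following the pattern in the question (values are illustrative).
edge_index = torch.tensor([[0, 0, 0, 1, 1, 1],
                           [0, 2, 3, 1, 2, 3]])
edge_attr = torch.arange(1, 19, dtype=torch.float32).view(6, 3) / 10
dense, mask = edges_to_dense(edge_index, edge_attr, num_nodes=4)
print(dense.shape)  # torch.Size([4, 3, 3])
```

Because the sort is stable, edge_index does not need to be pre-sorted: out[i, k] is the k-th edge of node i in its original order, and the mask distinguishes padding from real zero-valued features.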

I can think of a way to do this with a pandas pivot, but the duplicate node indices (node 0 has 3 connections) complicate it, and it is slow. Is there a torch way of doing this?
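A shorter torch-only alternative is to split the sorted edge features into one chunk per node and pad them with torch.nn.utils.rnn.pad_sequence. If PyTorch Geometric is already in use, I believe torch_geometric.utils.to_dense_batch(edge_attr, edge_index[0]) behaves similarly when edge_index[0] is sorted (treating source nodes like a batch vector), but that is worth checking against its documentation; the sketch below sticks to core PyTorch. The helper name edges_to_dense_padded is hypothetical, and it assumes source node ids lie in 0..num_nodes-1.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def edges_to_dense_padded(edge_index, edge_attr, num_nodes, fill_value=0.0):
    """Hypothetical helper: group edge features per source node and pad to the longest group.

    Assumes edge_index[0] holds source node ids in 0..num_nodes-1.
    """
    src = edge_index[0]
    # Stable sort keeps each node's edges in their original relative order.
    src_sorted, order = torch.sort(src, stable=True)
    counts = torch.bincount(src_sorted, minlength=num_nodes)
    # One [k_i, F] tensor per node; k_i may be 0 for nodes without outgoing edges.
    per_node = list(torch.split(edge_attr[order], counts.tolist()))
    # Stack into [num_nodes, max_n_edges, F], filling missing slots with fill_value.
    return pad_sequence(per_node, batch_first=True, padding_value=fill_value)

# Toy input following the pattern in the question (values are illustrative).
edge_index = torch.tensor([[0, 0, 0, 1, 1, 1],
                           [0, 2, 3, 1, 2, 3]])
edge_attr = torch.arange(1, 19, dtype=torch.float32).view(6, 3) / 10
dense = edges_to_dense_padded(edge_index, edge_attr, num_nodes=4)
print(dense.shape)  # torch.Size([4, 3, 3])
```

This reads more simply, but counts.tolist() forces a device-to-host copy and builds a Python list of num_nodes small tensors, which may add overhead on very large graphs. A fully vectorized variant computes each edge's slot index within its node's group and assigns with advanced indexing instead.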