Tensor Indexing with Backpropagation

I am trying to do a regression task on some graphs.
The data I have is structured like this:

  • “x”: a float tensor holding the features of every node, of shape (number of nodes, number of node features)
  • “edge_index”: a tensor holding the indices of the start and end node of every edge, of shape (2, number of edges)

I want to combine them into a tensor of shape (number of edges, 2*number of node features), i.e. concatenate the features of each edge's start node and end node into a single row.
For example:

x = tensor([[ 0,  1,  2],           # 4 nodes with 3 features each
            [ 3,  4,  5],
            [ 6,  7,  8],
            [ 9, 10, 11]])

edge_index = tensor([[0, 1, 2, 3],  # 4 edges: the first goes from node 0 to node 1,
                     [1, 2, 3, 0]]) # the last goes from node 3 to node 0

result = tensor([[ 0,  1,  2,  3,  4,  5],  # combined tensor of shape (number of edges, 2*number of node features)
                 [ 3,  4,  5,  6,  7,  8],
                 [ 6,  7,  8,  9, 10, 11],
                 [ 9, 10, 11,  0,  1,  2]])
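The example above can be written out as runnable tensors. This is a minimal sketch: `torch.arange(12).reshape(4, 3)` is just a compact way to build the same `x`, and the names `src_feats` and `dst_feats` are illustrative. Indexing `x` with one row of `edge_index` picks out one feature row per edge, which yields exactly the left and right halves of `result`.

```python
import torch

# Same 4-node, 3-feature example as above (integers only for readability)
x = torch.arange(12).reshape(4, 3)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

# One feature row per edge: row e holds the features of edge e's start/end node
src_feats = x[edge_index[0]]  # shape (num_edges, num_features)
dst_feats = x[edge_index[1]]  # shape (num_edges, num_features)

result = torch.tensor([[0, 1, 2, 3, 4, 5],
                       [3, 4, 5, 6, 7, 8],
                       [6, 7, 8, 9, 10, 11],
                       [9, 10, 11, 0, 1, 2]])

# The left half of result is src_feats, the right half is dst_feats
assert torch.equal(result[:, :3], src_feats)
assert torch.equal(result[:, 3:], dst_feats)
```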

My first approach was to use torch.index_select, but after implementing it I concluded that it doesn’t support backpropagation, which made it useless for me.

My second approach relied on slices:

num_edges = edge_index.size(1)
output = x.new_zeros(num_edges, 2 * x.size(1))  # no requires_grad=True: in-place writes into a leaf that requires grad raise an error
for i in range(num_edges):
    src, dst = edge_index[0, i].item(), edge_index[1, i].item()
    output[i] = torch.cat((x[src], x[dst]))  # gradients flow back to x through this copy

This approach gives the desired result and supports backpropagation, but it is extremely slow: processing a graph with 1500 nodes and 8000 edges takes nearly 4.5 seconds!
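Much of that time is likely per-operation overhead: the loop issues several tiny tensor operations and two `.item()` calls for every edge. A rough way to check this is to time the loop against a single batched gather on a random graph of the size mentioned (1500 nodes, 8000 edges). This is a sketch under assumptions: the feature count of 16, the random graph, and the helper names `loop_version` and `batched_version` are all hypothetical, and the printed timings will depend on hardware.

```python
import time
import torch

def loop_version(x, edge_index):
    # Per-edge Python loop, as in the slice-based approach above
    num_edges = edge_index.size(1)
    out = x.new_zeros(num_edges, 2 * x.size(1))
    for i in range(num_edges):
        src, dst = edge_index[0, i].item(), edge_index[1, i].item()
        out[i] = torch.cat((x[src], x[dst]))
    return out

def batched_version(x, edge_index):
    # One gather per endpoint for all edges at once, then a single concatenation
    return torch.cat((x[edge_index[0]], x[edge_index[1]]), dim=1)

torch.manual_seed(0)
num_nodes, num_edges, num_features = 1500, 8000, 16  # 16 features is an assumption
x = torch.randn(num_nodes, num_features, requires_grad=True)
edge_index = torch.randint(0, num_nodes, (2, num_edges))

outputs = {}
for fn in (loop_version, batched_version):
    start = time.perf_counter()
    outputs[fn.__name__] = fn(x, edge_index)
    print(f"{fn.__name__}: {time.perf_counter() - start:.4f} s")

# Both versions produce identical values; only the speed differs
assert torch.equal(outputs["loop_version"], outputs["batched_version"])
```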

I want to find an approach that both supports backpropagation and is fast.

Any suggestions would be greatly appreciated!
Thank you in advance!

torch.index_select does support backpropagation:

import torch

x = torch.randn(3, 3, requires_grad=True)
out = torch.index_select(x, 0, torch.tensor([0, 2]))  # pick rows 0 and 2
out.mean().backward()
print(x.grad)
# tensor([[0.1667, 0.1667, 0.1667],
#         [0.0000, 0.0000, 0.0000],
#         [0.1667, 0.1667, 0.1667]])
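Because `torch.index_select` is differentiable, the whole combined tensor can be built with one `index_select` per endpoint and a single `torch.cat`, with no Python loop. A minimal sketch, assuming float node features (gradients require a floating-point dtype); `combine_edge_features` is a hypothetical helper name. Advanced indexing (`x[edge_index[0]]`) is an equivalent alternative and is differentiable as well.

```python
import torch

def combine_edge_features(x, edge_index):
    # x: (num_nodes, num_features); edge_index: (2, num_edges), integer dtype
    src = torch.index_select(x, 0, edge_index[0])  # start-node features per edge
    dst = torch.index_select(x, 0, edge_index[1])  # end-node features per edge
    return torch.cat((src, dst), dim=1)            # (num_edges, 2 * num_features)

x = torch.tensor([[0., 1., 2.],
                  [3., 4., 5.],
                  [6., 7., 8.],
                  [9., 10., 11.]], requires_grad=True)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

out = combine_edge_features(x, edge_index)
print(out)

# Backpropagate through the gather: every node appears once as a start node
# and once as an end node, so each entry of x.grad is 2 under a plain sum
out.sum().backward()
print(x.grad)
```

Here `index_select` scatters the incoming gradient back to the selected rows during the backward pass, so nodes that appear in several edges simply accumulate gradient from each of them.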
