Hi,
I am trying to run a GNN with the following dataset structure per graph (batch_size=1):

- `node_fts`: `[207, 2]`
- `edge_fts`: `[1722]`
- `edges`: `[2, 1722]`
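To make these shapes concrete, here is a small sketch that builds random stand-in tensors with the same layout. The values are made up. It assumes, as the loop below does, that `edges[0]` is the node whose neighborhood is being collected and `edges[1]` is the neighbor:

```python
import torch

torch.manual_seed(0)
num_nodes, num_edges, node_dim = 207, 1722, 2

# hypothetical random data with the same shapes as one graph
node_fts = torch.randn(num_nodes, node_dim)          # [207, 2]
edges = torch.randint(0, num_nodes, (2, num_edges))  # [2, 1722]: row 0 = center node, row 1 = neighbor
edge_fts = torch.randn(num_edges)                    # [1722]: one scalar per edge

# features of the neighbor at the end of every edge: [1722, 2]
dst_node_fts = node_fts[edges[1]]

# one node's neighborhood = the edges whose row-0 entry is that node
node_idx = 0
mask = edges[0] == node_idx
nbr_fts = dst_node_fts[mask]                         # [k, 2], k varies from node to node
```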
I want to aggregate the `node_fts` of each node's neighborhood (its neighbors ordered by `edge_fts`, descending) and pass them into an LSTM, but the neighborhoods have different sizes: for example, the first node is connected to 12 nodes while another is connected to 5. I am currently padding the sequences with a for-loop:
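```python
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

# features of the neighbor at the end of every edge: [num_edges, 2]
dst_node_fts = node_fts[edges[1]]

seqs, lengths = [], []
for node_idx in range(node_fts.shape[0]):
    mask = edges[0] == node_idx  # edges leaving node_idx
    # neighbor features, sorted by edge feature (largest first)
    order = torch.argsort(edge_fts[mask], descending=True)
    seqs.append(dst_node_fts[mask][order])
    lengths.append(len(seqs[-1]))
lengths = torch.tensor(lengths)
seqs = pad_sequence(seqs, batch_first=True)  # [num_nodes, max_len, 2]
packed = nn.utils.rnn.pack_padded_sequence(
    seqs, lengths.to("cpu"), batch_first=True, enforce_sorted=False
)
```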
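For context, a minimal sketch of where `packed` goes next, on three hypothetical toy neighborhoods of lengths 3, 1 and 2 (the LSTM `hidden_size=8` is an arbitrary choice):

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

torch.manual_seed(0)
# toy neighborhoods of different sizes; each row is a 2-dim neighbor feature
seqs = [torch.randn(3, 2), torch.randn(1, 2), torch.randn(2, 2)]
lengths = torch.tensor([len(s) for s in seqs])

padded = pad_sequence(seqs, batch_first=True)  # [3, 3, 2], zero-padded on the right
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)

lstm = nn.LSTM(input_size=2, hidden_size=8, batch_first=True)
_, (h_n, _) = lstm(packed)
node_emb = h_n[-1]  # [3, 8]: last hidden state of each sequence, in the original order
```

With `enforce_sorted=False`, `nn.LSTM` returns `h_n` in the original (unsorted) order, so row `i` of `node_emb` summarizes node `i`'s neighborhood and the padding never enters the recurrence.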
This per-node loop is of course absurdly slow: with batch size 64 it runs 64 × 207 = 13,248 times. Is there a way to vectorize it? I can already vectorize the lengths tensor:
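```python
import torch
import torch_scatter

# count outgoing edges per node = neighborhood size
lengths = torch_scatter.scatter_add(
    torch.ones(batch.edge_index.shape[1]),
    batch.edge_index[0],
    dim=0,
    dim_size=batch.x.shape[0],
)
```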
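If `torch_scatter` isn't otherwise needed, the same counts can be obtained with plain `torch.bincount`, which returns integer counts directly. A small sketch on a hypothetical 4-node edge list:

```python
import torch

# hypothetical toy edge list: row 0 = center node, row 1 = neighbor
edge_index = torch.tensor([[0, 0, 0, 1, 2, 2],
                           [1, 2, 3, 0, 0, 3]])
num_nodes = 4

# count the edges leaving each node; minlength keeps nodes without edges as 0
lengths = torch.bincount(edge_index[0], minlength=num_nodes)
print(lengths)  # tensor([3, 1, 2, 0])
```

Note that a node with no outgoing edges (node 3 here) gets length 0, which `pack_padded_sequence` rejects, so isolated nodes would need separate handling before packing.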
`lengths` now holds the neighborhood size of every node. Then I allocate a zero-padded tensor and try to fill it:
```python
dst_node_fts = node_fts[edges[1]]
max_neighborhood_size = int(lengths.max().item())  # 19
# tensor of zeros with shape (num_nodes, max_neighborhood_size=19, node_fts_dim=2)
neighborhood_tensor = torch.zeros((batch.x.shape[0], max_neighborhood_size, 2))
# one node's neighbors, sorted by edge feature, written into its first k slots
mask = edges[0] == node_idx
k = int(lengths[node_idx])
neighborhood_tensor[node_idx, :k] = dst_node_fts[mask][
    torch.argsort(edge_fts[mask], descending=True)
]
```
The last block, `neighborhood_tensor[node_idx] = ...`, is where I'm stuck. I used `node_idx` only to illustrate what I'm trying to accomplish. Is there a way to fill `neighborhood_tensor` for all nodes at once, without loops?
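One possible loop-free approach, as a minimal sketch under the same assumptions as the loop (`edges[0]` is the center node, `edges[1]` the neighbor, neighbors ordered by `edge_fts` descending): sort all edges by edge feature, then stably by source node, so each node's edges form one contiguous, correctly ordered run; an edge's slot inside its run is its global position minus the run's start offset. The helper name `dense_neighborhoods` is hypothetical.

```python
import torch

def dense_neighborhoods(node_fts, edges, edge_fts):
    """Hypothetical helper: pack each node's neighbor features into [N, max_deg, F].

    Row n holds node_fts[edges[1]] for all edges with edges[0] == n,
    ordered by edge_fts (descending) and zero-padded on the right.
    """
    src, dst = edges[0], edges[1]
    num_nodes, feat_dim = node_fts.shape

    # Sort edges by edge feature (descending), then stably by source node:
    # each node's edges become one contiguous run that keeps the descending order.
    perm = torch.argsort(edge_fts, descending=True)
    _, by_src = torch.sort(src[perm], stable=True)
    perm = perm[by_src]
    src_sorted = src[perm]

    # Neighborhood sizes and the offset where each node's run starts.
    lengths = torch.bincount(src, minlength=num_nodes)
    starts = torch.cumsum(lengths, dim=0) - lengths

    # Position of every sorted edge inside its own run: 0, 1, 2, ...
    slot = torch.arange(src.numel(), device=src.device) - starts[src_sorted]

    max_len = int(lengths.max()) if lengths.numel() > 0 else 0
    out = node_fts.new_zeros((num_nodes, max_len, feat_dim))
    out[src_sorted, slot] = node_fts[dst[perm]]  # each (node, slot) pair is unique
    return out, lengths
```

`out` would take the place of the padded `seqs` from the loop, and `lengths` (moved to CPU) could go to `pack_padded_sequence` as before, provided no node has zero neighbors. Ties in `edge_fts` may come out in a different order than in the loop, since neither version sorts the edge features stably.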