Issue with PyTorch Geometric (clustering) and batches

Hi *,

I’m fairly new to PyTorch Geometric, and I seem to be missing something when processing a batch with one of the methods in torch-cluster.

My tensor has shape [batch_size, 10, 8], i.e. 10 data points with 8 features each for every sample in the batch (this is after the embedding stage, so these are embedding vectors).
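For reference, as far as I understand the torch-cluster docs, functions such as radius_graph expect a flat [N, F] point tensor plus an optional batch vector of length N that says which sample each row belongs to, and that batch vector must be sorted. A minimal sketch of building both from a [batch_size, 10, 8] tensor, assuming random data as a stand-in for model(input_data):

```python
import torch

batch_size, num_points, num_features = 4, 10, 8

# Stand-in for model(input_data) (assumption): embeddings of shape [batch_size, 10, 8].
x = torch.randn(batch_size, num_points, num_features)

# Flatten the batch dimension so every point becomes one row: [batch_size * 10, 8].
x_flat = x.reshape(-1, num_features)

# batch[i] is the sample that row i came from: 0 (x10), 1 (x10), ... so it is sorted.
batch = torch.arange(batch_size).repeat_interleave(num_points)

# With torch-cluster installed, the call would then be (not executed here):
#   edge_index = radius_graph(x_flat, r=1.0, batch=batch)
```

Because reshape keeps rows in order, rows 10 to 19 of x_flat are exactly the points of sample 1, which is what makes the simple repeat_interleave batch vector line up.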

Next, I want to find the points that are closest to each other in this embedding space so I can focus the training on them. However, I can’t get it to run on the batched tensor.

I’m using

from torch_cluster import radius_graph

x = model(input_data)
# x.shape == [batch_size, 10, 8]
radius_graph(x, r=1)     # fails

radius_graph(x[0], r=1)  # works

I don’t understand the batch parameter the function asks for, and it always fails an assertion on the dimensions of the input tensor.
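To illustrate what the batch argument is meant to do, here is a dense O(N^2) reference sketch in plain PyTorch. radius_graph_dense is a hypothetical name, not part of torch-cluster. It connects two points only if they are within distance r and share the same batch entry, and it returns a single [2, E] edge_index with global indices into the flattened tensor. It only illustrates the semantics: it does not cap neighbours like the library's max_num_neighbors argument, edge order may differ, and the boundary case dist == r may be handled differently.

```python
import torch

def radius_graph_dense(x: torch.Tensor, r: float, batch: torch.Tensor,
                       loop: bool = False) -> torch.Tensor:
    """Hypothetical reference helper (not the torch-cluster implementation).

    x:     [N, F] flattened points
    batch: [N] sample index of each point
    returns edge_index of shape [2, E] with global (flattened) indices
    """
    n = x.size(0)
    # Exact pairwise Euclidean distances, [N, N].
    dist = torch.cdist(x, x, compute_mode="donot_use_mm_for_euclid_dist")
    close = dist <= r                                   # within the radius
    same = batch.unsqueeze(0) == batch.unsqueeze(1)     # same sample only
    mask = close & same
    if not loop:
        # Drop self-edges (each point is trivially within r of itself).
        mask = mask & ~torch.eye(n, dtype=torch.bool, device=x.device)
    src, dst = mask.nonzero(as_tuple=True)
    return torch.stack([src, dst], dim=0)
```

Two points that are close in feature space but live in different samples (for example rows 0 and 3 below, both at 0.0) are never connected, which is the behaviour that slicing x[0] gave for a single sample.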

I would expect the return value to have the shape

[batch_size, 2, n_matched]
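Since each sample can have a different number of close pairs, the result generally cannot be a rectangular [batch_size, 2, n_matched] tensor. One option, sketched below under the assumption that edges only connect points from the same sample (which the batch vector ensures), is to split the flat edge_index into a list with one [2, E_b] tensor per sample and shift the indices back to local positions within each sample. split_edges_per_sample is a hypothetical helper:

```python
from typing import List

import torch

def split_edges_per_sample(edge_index: torch.Tensor, batch: torch.Tensor,
                           num_samples: int) -> List[torch.Tensor]:
    """Hypothetical helper: split a flat [2, E] edge_index into one
    [2, E_b] tensor per sample, re-indexed to local point positions."""
    # Number of points per sample and the offset of each sample's first point.
    counts = torch.bincount(batch, minlength=num_samples)
    offsets = torch.cumsum(counts, dim=0) - counts
    # Sample each edge belongs to (both endpoints share it by construction).
    edge_batch = batch[edge_index[0]]
    out = []
    for b in range(num_samples):
        e = edge_index[:, edge_batch == b]
        out.append(e - offsets[b])  # global -> local indices
    return out
```

Samples with no close pairs simply get an empty [2, 0] tensor, so the list always has batch_size entries.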

Any help would be appreciated. Thank you!