Hi *,
I’m a bit new to PyTorch Geometric, and I seem to be missing something when processing a batch with one of the methods in torch-cluster.
My tensor has shape [batch_size, 10, 8], i.e. 10 data points with 8 features each for every example in the batch (this is after the embedding stage, so these are embedding vectors).
Next, I want to find the points that lie close to each other in this embedding space so I can focus the training on them. However, I can’t get this to run on the batched tensor.
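For a batch this small (10 points per example), one way to find close pairs without the batch argument at all is a dense pairwise-distance check. The sketch below uses plain PyTorch (torch.cdist, which accepts batched [B, N, F] inputs) instead of torch-cluster; the helper name radius_pairs_dense and the choice of Euclidean distance with an inclusive threshold are assumptions for illustration.

```python
import torch


def radius_pairs_dense(x: torch.Tensor, r: float) -> torch.Tensor:
    """Hypothetical helper: [B, N, N] boolean mask of point pairs within distance r.

    x: [B, N, F] batch of embeddings. Self-pairs (the diagonal) are excluded.
    """
    # torch.cdist handles the leading batch dimension and returns [B, N, N]
    dist = torch.cdist(x, x)
    mask = dist <= r
    # Remove each point's pair with itself in every example
    eye = torch.eye(x.size(1), dtype=torch.bool, device=x.device)
    return mask & ~eye


# Usage: (example index, point i, point j) for every close pair
x = torch.randn(4, 10, 8)
b, i, j = radius_pairs_dense(x, r=1.0).nonzero(as_tuple=True)
```

This costs O(N^2) memory per example, which is negligible for N = 10 but is the reason neighbor-search libraries exist for larger point sets.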
I’m using:
from torch_cluster import radius_graph
x = model(input_data)
# x.shape == (batch_size, 10, 8)
radius_graph(x, r=1)     # fails
radius_graph(x[0], r=1)  # works
I don’t understand the batch parameter the function asks for, and it keeps failing an assertion on the dimensions of the input tensor.
I would expect to get a return value of shape [batch_size, 2, n_matched].
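One caveat on the expected shape: radius_graph returns a single [2, E] edge_index over all flattened points, and each example can have a different number of close pairs, so a [batch_size, 2, n_matched] tensor cannot exist in general without padding. A sketch of recovering per-example edges with local indices in [0, 10), assuming a flat edge_index and a sorted batch vector as produced by a batch-aware radius_graph call (the helper name split_edges_per_example is hypothetical):

```python
import torch


def split_edges_per_example(edge_index, batch, num_examples, nodes_per_example):
    """Hypothetical helper: split a flat [2, E] edge_index into per-example edges.

    Assumes every example has exactly nodes_per_example points and that edges
    never cross examples (which the batch argument guarantees).
    """
    # The source node's batch id tells which example an edge belongs to
    graph_id = batch[edge_index[0]]
    # Convert global row indices back to indices within each example
    local = edge_index % nodes_per_example
    # A list of [2, E_b] tensors; E_b may differ between examples
    return [local[:, graph_id == b] for b in range(num_examples)]


# Usage with a hand-written edge_index for 2 examples of 3 points each
edge_index = torch.tensor([[0, 1, 4, 5], [1, 0, 5, 4]])
batch = torch.arange(2).repeat_interleave(3)
parts = split_edges_per_example(edge_index, batch, num_examples=2, nodes_per_example=3)
```

If a fixed-shape result is preferred, a dense [batch_size, N, N] adjacency is the usual alternative; PyTorch Geometric provides torch_geometric.utils.to_dense_adj(edge_index, batch) for that conversion.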
Any help would be appreciated. Thank you!