Most efficient, differentiable way to compute pairwise distances

Hi all!
I am looking for the most efficient, differentiable way to compute, for a 3D point cloud matrix of shape (1024, 3), the vector containing all pairwise distances (shape: (1024x1024, 1)).
Currently, I use this:

import torch
import torch.nn as nn

x = torch.ones(1024, 3, requires_grad=True)
pdist = nn.PairwiseDistance(p=2, keepdim=True)
out = torch.cat([pdist(x[n], x[i]) for n in range(len(x)) for i in range(len(x))])

which does not seem to be differentiable:

torch.autograd.grad(out.sum(), x)[0]

RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_1030593/3913672061.py in <module>
----> 1 torch.autograd.grad(out.sum(), x)[0]

~/miniconda3/envs/test_torch/lib/python3.7/site-packages/torch/autograd/__init__.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched)
    300         return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    301             t_outputs, grad_outputs_, retain_graph, create_graph, t_inputs,
--> 302             allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass
    303 
    304 

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
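This error means `out` carries no autograd graph, which typically happens when the input tensor was created without `requires_grad=True`. A minimal sketch reproducing both cases, assuming a small 4-point cloud instead of 1024 points so the double loop runs quickly (`pairwise_loop` is a hypothetical helper wrapping the loop above):

```python
import torch
import torch.nn as nn

pdist = nn.PairwiseDistance(p=2, keepdim=True)

def pairwise_loop(x):
    # Hypothetical helper: the same double loop, one distance per ordered pair (n, i)
    return torch.cat([pdist(x[n], x[i]) for n in range(len(x)) for i in range(len(x))])

# Without requires_grad=True autograd records no graph, so grad() raises RuntimeError
x_plain = torch.ones(4, 3)
try:
    torch.autograd.grad(pairwise_loop(x_plain).sum(), x_plain)
    raised = False
except RuntimeError:
    raised = True

# With requires_grad=True the same loop is differentiable
x_tracked = torch.ones(4, 3, requires_grad=True)
grad = torch.autograd.grad(pairwise_loop(x_tracked).sum(), x_tracked)[0]
print(raised, grad.shape)
```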

Thank you so much!

Your code works for me:

import torch
import torch.nn as nn

x = torch.ones(10, 3, requires_grad=True)
pdist = nn.PairwiseDistance(p=2, keepdim=True)
out = torch.cat([pdist(x[n], x[i]) for n in range(len(x)) for i in range(len(x))])
out.mean().backward()
print(x.grad)
# tensor([[ 9.3132e-10,  9.3132e-10,  9.3132e-10],
#         [ 9.3132e-10,  9.3132e-10,  9.3132e-10],
#         [ 9.3132e-10,  9.3132e-10,  9.3132e-10],
#         [ 9.3132e-10,  9.3132e-10,  9.3132e-10],
#         [-9.3132e-10, -9.3132e-10, -9.3132e-10],
#         [-9.3132e-10, -9.3132e-10, -9.3132e-10],
#         [-9.3132e-10, -9.3132e-10, -9.3132e-10],
#         [-9.3132e-10, -9.3132e-10, -9.3132e-10],
#         [-9.3132e-10, -9.3132e-10, -9.3132e-10],
#         [-9.3132e-10, -9.3132e-10, -9.3132e-10]])
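The gradient entries are only about 1e-9 because all points in `torch.ones(10, 3)` coincide: every difference vector is the same small `eps` vector, so each point receives equal and opposite contributions as the first and second argument, and they cancel up to float32 rounding. To confirm the loop is differentiable on non-degenerate data, a sketch using `torch.autograd.gradcheck`, assuming random float64 inputs (gradcheck compares analytical gradients against finite differences and expects double precision):

```python
import torch
import torch.nn as nn

pdist = nn.PairwiseDistance(p=2, keepdim=True)

def pairwise_loop(x):
    # One distance per ordered pair (n, i), concatenated into a flat vector
    return torch.cat([pdist(x[n], x[i]) for n in range(len(x)) for i in range(len(x))])

torch.manual_seed(0)
# Random, distinct points so the distances are not degenerate
x = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
# Returns True if analytical and numerical Jacobians agree (raises otherwise)
ok = torch.autograd.gradcheck(pairwise_loop, (x,))
print(ok)
```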

I think I forgot the requires_grad=True line in my original code … ^^
Thank you Patrick!
As a side question, do you know if there is a more efficient (built-in) way to arrive at this pairwise distance for x?

Yes, you could add dimensions with unsqueeze and let broadcasting replace the for loop:

import torch
import torch.nn as nn

x = torch.randn(10, 3, requires_grad=True)
pdist = nn.PairwiseDistance(p=2, keepdim=True)
# Reference: the original double loop
ref = torch.cat([pdist(x[n], x[i]) for n in range(len(x)) for i in range(len(x))])

# (10, 1, 3) and (1, 10, 3) broadcast to (10, 10, 3); the result has shape (10, 10, 1)
out = pdist(x.unsqueeze(1), x.unsqueeze(0))
out = out.view(-1)
print((out - ref).abs().max())
# tensor(0., grad_fn=<MaxBackward1>)
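As a built-in alternative, `torch.cdist` computes the full N x N matrix of p-norm distances between the rows of two inputs and is also differentiable. A sketch, assuming the same 10-point random cloud; note that, unlike `nn.PairwiseDistance`, `cdist` does not add the `eps=1e-6` term to the difference, so the diagonal is exactly 0 and the other values differ by at most about 2e-6:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(10, 3, requires_grad=True)

# Broadcasting version: (10, 1, 3) vs (1, 10, 3) -> (10, 10, 1), flattened to 100 values
pdist = nn.PairwiseDistance(p=2, keepdim=True)
out_bcast = pdist(x.unsqueeze(1), x.unsqueeze(0)).view(-1)

# Built-in full distance matrix: (10, 3) and (10, 3) -> (10, 10), flattened to 100 values
out_cdist = torch.cdist(x, x, p=2.0).view(-1)

# The differences come from the eps term nn.PairwiseDistance adds (plus float rounding)
max_diff = (out_bcast - out_cdist).abs().max().item()
print(max_diff)

# cdist is differentiable as well
out_cdist.sum().backward()
print(x.grad.shape)
```

For larger clouds such as 1024 points, `cdist` with `p=2` switches by default to a matrix-multiplication-based formula, which is faster but can be slightly less accurate; passing `compute_mode='donot_use_mm_for_euclid_dist'` forces the direct computation.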

Oh wow, that is amazing! Does this technique have a specific name or can I find more information on that in the PyTorch documentation? The computation time is now seconds instead of minutes :slight_smile:

I would refer to it as broadcasting and the ability to execute operators on batched inputs. This doc explains it a bit more and also links to the numpy docs, which has more examples.
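To make the broadcasting step explicit, the same computation can be written with plain tensor arithmetic: `x[:, None, :]` has shape (N, 1, 3), `x[None, :, :]` has shape (1, N, 3), and their difference broadcasts to (N, N, 3) with entry `[n, i] = x[n] - x[i]`. A sketch, assuming the default `eps=1e-6` of `nn.PairwiseDistance` is added so the result matches the original loop:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(10, 3, requires_grad=True)
eps = 1e-6  # default eps of nn.PairwiseDistance

# (10, 1, 3) - (1, 10, 3) broadcasts to (10, 10, 3): diff[n, i] = x[n] - x[i] + eps
diff = x[:, None, :] - x[None, :, :] + eps
# Euclidean norm over the coordinate axis -> (10, 10), flattened to 100 distances
out = torch.linalg.vector_norm(diff, ord=2, dim=-1).view(-1)

# Reference: the original double loop
pdist = nn.PairwiseDistance(p=2, keepdim=True)
ref = torch.cat([pdist(x[n], x[i]) for n in range(len(x)) for i in range(len(x))])

print(diff.shape, out.shape)
print((out - ref).abs().max().item())
```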
