I’m trying to use the torch.cdist function in PyTorch. However, my definition of distance differs slightly from the one the function uses.
For example, with `torch.cdist`:
```python
import torch

a = torch.tensor([[1., 0., 1.], [1., 0., 1.], [1., 0., 1.]])
b = torch.tensor([[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]])
dis = torch.cdist(a, b, p=2)
print(dis)
# tensor([[1.7321, 1.7321, 1.7321],
#         [1.7321, 1.7321, 1.7321],
#         [1.7321, 1.7321, 1.7321]])
```
Here, dis[i, j] = sqrt(sum((a[i] - b[j])**2)), i.e. the Euclidean distance between row i of a and row j of b.
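That formula can be reproduced directly with broadcasting, which is also a convenient place to plug in a custom rule later. A minimal sketch on the same tensors, assuming p=2; it just checks that the explicit formula agrees with `torch.cdist`:

```python
import torch

a = torch.tensor([[1., 0., 1.], [1., 0., 1.], [1., 0., 1.]])
b = torch.tensor([[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]])

# Broadcast to shape (N, M, D): diff[i, j] = a[i] - b[j]
diff = a[:, None, :] - b[None, :, :]
# Euclidean norm over the last (feature) dimension, shape (N, M)
manual = diff.pow(2).sum(dim=-1).sqrt()

reference = torch.cdist(a, b, p=2)
print(torch.allclose(manual, reference))  # True
```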
However, my definition is: for dis[i, j], first set elements i and j of the difference vector a[i] - b[j] to 0, and then compute dis[i, j] = sqrt(sum((a[i] - b[j])**2)) on that modified vector.
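One way to express this definition is to build the full (N, M, D) difference tensor with broadcasting and zero out the two features with a boolean mask. This is a sketch, not a PyTorch API; `masked_cdist` is a hypothetical helper name, and it assumes a is (N, D), b is (M, D) with N <= D and M <= D, so that the row indices i and j are also valid feature indices:

```python
import torch


def masked_cdist(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """dis[i, j] = sqrt(sum(d**2)) with d = a[i] - b[j] and d[i] = d[j] = 0.

    Hypothetical helper. Assumes a has shape (N, D), b has shape (M, D),
    and N <= D, M <= D so row indices are also valid feature indices.
    """
    n, d = a.shape
    m = b.shape[0]
    # diff[i, j, k] = a[i, k] - b[j, k], shape (N, M, D)
    diff = a[:, None, :] - b[None, :, :]
    k = torch.arange(d, device=a.device)
    rows_a = torch.arange(n, device=a.device)
    rows_b = torch.arange(m, device=a.device)
    # mask[i, j, k] is True where k == i or k == j
    mask = (k[None, None, :] == rows_a[:, None, None]) | (
        k[None, None, :] == rows_b[None, :, None]
    )
    # Zero the masked features, then take the Euclidean norm over k
    return diff.masked_fill(mask, 0.0).pow(2).sum(dim=-1).sqrt()


a = torch.tensor([[1., 0., 1.], [1., 0., 1.], [1., 0., 1.]])
b = torch.tensor([[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]])
print(masked_cdist(a, b))
```

For the example tensors, every a[i] - b[j] is [1, -1, 1], so this gives sqrt(2) on the diagonal (only one element is zeroed when i == j) and 1 everywhere else. If you need gradients, keep in mind that sqrt has an infinite derivative at 0, so pairs whose masked difference is exactly zero can produce NaN gradients.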
My question is: is there an API, or a place in the source code, that would let me customize the distance in this way?
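As far as I know, `torch.cdist` only exposes `p` and `compute_mode`, with no hook for a per-pair rule like this. The broadcasting approach materializes an (N, M, D) tensor, which can be large. A lower-memory workaround (my own sketch, not a PyTorch feature; `masked_cdist_lowmem` is a hypothetical name) reuses `torch.cdist` for the full squared distance and subtracts the contributions of features i and j, under the same assumption N <= D and M <= D:

```python
import torch


def masked_cdist_lowmem(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Same result as zeroing features i and j of a[i] - b[j], without an (N, M, D) tensor.

    Hypothetical helper. Assumes a is (N, D), b is (M, D), N <= D and M <= D.
    """
    n, d = a.shape
    m = b.shape[0]
    assert n <= d and m <= d, "row indices must also be valid feature indices"
    # Full squared Euclidean distances, shape (N, M)
    sq = torch.cdist(a, b, p=2).pow(2)
    # e1[i, j] = (a[i, i] - b[j, i])**2: contribution of feature k = i
    e1 = (a.diagonal()[:, None] - b[:, :n].T).pow(2)
    # e2[i, j] = (a[i, j] - b[j, j])**2: contribution of feature k = j
    e2 = (a[:, :m] - b.diagonal()[None, :]).pow(2)
    # When i == j both terms are the same feature, so remove it only once
    eye = torch.eye(n, m, dtype=a.dtype, device=a.device)
    corrected = sq - e1 - e2 + eye * e1
    # clamp_min guards against tiny negative values from floating-point error
    return corrected.clamp_min(0).sqrt()


a = torch.tensor([[1., 0., 1.], [1., 0., 1.], [1., 0., 1.]])
b = torch.tensor([[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]])
print(masked_cdist_lowmem(a, b))
```

The trade-off is numerical: subtracting from a squared distance can lose precision when the removed terms dominate, which is why the result is clamped before the square root.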