Custom cdist in PyTorch

I'm trying to use the torch.cdist function in PyTorch. However, my definition of distance differs slightly from the one the function uses.

For example, with torch.cdist:

import torch

a = torch.tensor([[1., 0., 1.], [1., 0., 1.], [1., 0., 1.]])
b = torch.tensor([[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]])
dis = torch.cdist(a,b,p=2)
print(dis)
tensor([[1.7321, 1.7321, 1.7321],
        [1.7321, 1.7321, 1.7321],
        [1.7321, 1.7321, 1.7321]])

Here,

# dis[i,j] = sqrt(sum((a[i] - b[j])**2))
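To make this formula concrete, here is a small sketch that reproduces torch.cdist(a, b, p=2) with plain broadcasting. The name dis_manual is only for illustration; for inputs this small, cdist computes the differences directly, so the two results should agree.

```python
import torch

a = torch.tensor([[1., 0., 1.], [1., 0., 1.], [1., 0., 1.]])
b = torch.tensor([[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]])

# Broadcast to shape (n, m, d): diff[i, j] = a[i] - b[j]
diff = a[:, None, :] - b[None, :, :]
# dis_manual[i, j] = sqrt(sum((a[i] - b[j])**2))
dis_manual = diff.pow(2).sum(dim=-1).sqrt()

dis = torch.cdist(a, b, p=2)
print(torch.allclose(dis, dis_manual))  # True
```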

However, in my definition,

# for dis[i, j]:
# diff = a[i] - b[j]; set diff[i] = 0 and diff[j] = 0
# then dis[i, j] = sqrt(sum(diff**2))
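A minimal sketch of this definition using broadcasting, assuming that i and j are also valid column indices (a and b have at least as many columns as either has rows, as in the 3x3 example). masked_cdist is a hypothetical helper, not a PyTorch API. It builds an (n, m, d) tensor, so it is only meant for small inputs.

```python
import torch

def masked_cdist(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: Euclidean distance where, for each pair (i, j),
    components i and j of a[i] - b[j] are set to 0 before summing."""
    n, d = a.shape
    m = b.shape[0]
    # Assumption: row indices i and j must also be valid column indices
    if d < max(n, m):
        raise ValueError("need a.shape[1] >= max(a.shape[0], b.shape[0])")
    diff = a[:, None, :] - b[None, :, :]                         # (n, m, d)
    cols = torch.arange(d, device=a.device)                       # (d,)
    rows_i = torch.arange(n, device=a.device)[:, None, None]      # (n, 1, 1)
    rows_j = torch.arange(m, device=a.device)[None, :, None]      # (1, m, 1)
    # keep[i, j, k] is False exactly where k == i or k == j
    keep = (cols != rows_i) & (cols != rows_j)                    # (n, m, d)
    return diff.masked_fill(~keep, 0.0).pow(2).sum(dim=-1).sqrt()

a = torch.tensor([[1., 0., 1.], [1., 0., 1.], [1., 0., 1.]])
b = torch.tensor([[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]])
dis = masked_cdist(a, b)
# diagonal entries: sqrt(2) ≈ 1.4142 (only one component zeroed)
# off-diagonal entries: 1.0 (two components zeroed)
print(dis)
```

For this example a[i] - b[j] is always [1, -1, 1], so zeroing two components leaves one squared term (distance 1), while on the diagonal i == j only one component is zeroed (distance sqrt(2)).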

My question is: is there an API, or some part of the source code, that would let me customize the distance like this?
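torch.cdist only takes p and compute_mode, so there is no built-in hook for per-pair masks. One possible workaround, sketched here under the same assumption as before (column count at least the row counts of a and b), is to compute the full squared distance with torch.cdist and subtract the two squared components that should be zeroed. When i == j those two components are the same term, so it is added back once. masked_cdist_via_cdist is a hypothetical name; this variant only needs (n, m) intermediates instead of (n, m, d).

```python
import torch

def masked_cdist_via_cdist(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: full cdist distance with components i and j of
    a[i] - b[j] removed, computed as a correction on top of torch.cdist."""
    n, d = a.shape
    m = b.shape[0]
    # Assumption: row indices i and j must also be valid column indices
    if d < max(n, m):
        raise ValueError("need a.shape[1] >= max(a.shape[0], b.shape[0])")
    full_sq = torch.cdist(a, b, p=2).pow(2)                  # (n, m)
    idx_n = torch.arange(n, device=a.device)
    idx_m = torch.arange(m, device=a.device)
    # t1[i, j] = (a[i, i] - b[j, i])**2, i.e. component i of a[i] - b[j]
    t1 = (a[idx_n, idx_n][:, None] - b[:, :n].T).pow(2)
    # t2[i, j] = (a[i, j] - b[j, j])**2, i.e. component j of a[i] - b[j]
    t2 = (a[:, :m] - b[idx_m, idx_m][None, :]).pow(2)
    # When i == j, t1 and t2 are the same term: subtract it only once
    overlap = t1 * torch.eye(n, m, dtype=a.dtype, device=a.device)
    # Squaring the cdist output can introduce tiny negative rounding noise
    sq = (full_sq - t1 - t2 + overlap).clamp_min(0)
    return sq.sqrt()

a = torch.tensor([[1., 0., 1.], [1., 0., 1.], [1., 0., 1.]])
b = torch.tensor([[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]])
dis = masked_cdist_via_cdist(a, b)
print(dis)
```

The trade-off is precision: squaring the cdist output and subtracting can lose a little accuracy compared with masking the differences directly, which may matter if distances are tiny relative to the removed components.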

Could you give an example and the expected result for your case? The definition is not quite clear.