def pairwise_dist(x, y):
    """Squared Euclidean distance between every row of ``x`` and every row of ``y``.

    Uses the expansion ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>,
    so the cross term costs a single matrix multiply.

    Args:
        x: 2-D tensor of shape (N, D).
        y: 2-D tensor of shape (M, D); M may differ from N.

    Returns:
        Tensor of shape (N, M) with ``P[i, j] = ||x[i] - y[j]||^2``.
    """
    # Row norms computed directly: building the (N, N) / (M, M) Gram matrices
    # just to read their diagonals wastes O(N^2 * D) work and memory.
    rx = x.pow(2).sum(dim=1, keepdim=True)  # (N, 1)
    ry = y.pow(2).sum(dim=1).unsqueeze(0)   # (1, M)
    zz = torch.mm(x, y.t())                 # (N, M)
    # (N, 1) + (1, M) broadcasts to (N, M), so the two sets may differ in size.
    # Floating-point cancellation can leave tiny negatives; a squared distance
    # is never negative, so clamp at zero.
    return (rx + ry - 2 * zz).clamp(min=0)
def NN_loss(x, y, dim=0):
    """Average nearest-neighbour squared distance between two point sets.

    ``pairwise_dist(x, y)[i, j]`` holds the squared distance between ``x[i]``
    and ``y[j]``. Taking the minimum along ``dim`` keeps, for each point on
    the remaining axis, only its closest partner:

    * ``dim=0`` (default): for each row of ``y``, the closest row of ``x``.
    * ``dim=1``: for each row of ``x``, the closest row of ``y``.

    Returns:
        0-d tensor: the mean of those minimum distances.
    """
    closest = pairwise_dist(x, y).min(dim=dim)[0]  # [0] = values, [1] = argmin
    return closest.mean()
def batch_pairwise_dist(a, b):
    """Batched squared Euclidean distances between two sets of points.

    Args:
        a: tensor of shape (B, Na, D).
        b: tensor of shape (B, Nb, D); Nb may differ from Na.

    Returns:
        Tensor of shape (B, Na, Nb) with ``P[k, i, j] = ||a[k, i] - b[k, j]||^2``.
    """
    # Squared norms via a sum over the coordinate axis. This replaces indexing
    # the diagonal of full (B, N, N) Gram matrices, which was O(N^2 * D) extra
    # work and needed an index tensor pinned to torch.cuda (broke on CPU).
    rx = a.pow(2).sum(dim=-1).unsqueeze(2)  # (B, Na, 1)
    ry = b.pow(2).sum(dim=-1).unsqueeze(1)  # (B, 1, Nb)
    zz = torch.bmm(a, b.transpose(2, 1))    # (B, Na, Nb)
    # Broadcasting handles Na != Nb. Clamp removes tiny negatives left by
    # floating-point cancellation in rx + ry - 2*zz.
    return (rx + ry - 2 * zz).clamp(min=0)
def batch_NN_loss(x, y, dim=1):
    """Per-sample mean nearest-neighbour squared distance for batched point sets.

    ``batch_pairwise_dist(x, y)[b, i, j]`` is the squared distance between
    ``x[b, i]`` and ``y[b, j]``. The minimum is taken along ``dim``:

    * ``dim=1`` (default): for each point of ``y``, its closest point in ``x``.
    * ``dim=2`` (or -1): for each point of ``x``, its closest point in ``y``.

    Args:
        x: batched points, 3-D tensor (B, Nx, D).
        y: batched points, 3-D tensor (B, Ny, D).
        dim: axis of the distance tensor to minimise over; must not be the
            batch axis.

    Returns:
        Tensor of shape (B,): mean nearest-neighbour distance per batch item.

    Raises:
        ValueError: if ``dim`` refers to the batch axis (0 or -3).
    """
    # Explicit check instead of ``assert`` (stripped under ``python -O``);
    # -3 is the same batch axis as 0 for a 3-D tensor.
    if dim in (0, -3):
        raise ValueError(f"dim={dim} is the batch axis; use 1 or 2")
    dist = batch_pairwise_dist(x, y)
    nearest = dist.min(dim=dim)[0]  # [0] = values, [1] = argmin indices
    return nearest.mean(dim=-1)
Here is how I ended up implementing it. I'm not sure whether I'm extracting the diagonal elements as efficiently as possible, but it seems to work.