Understanding cdist() function

See the runtime error I get when I try to use a different function, my_cdist(), instead.

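import torch

def my_cdist(x1, x2):
    # Squared L2 norm of each row; for 2-D inputs the shapes are (N, 1) and (M, 1)
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
    # ||x2||^2 - 2 * x1 @ x2^T + ||x1||^2; note that torch.addmm expects 2-D matrices
    res = torch.addmm(x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
    # Clamp so rounding error cannot leave a negative value under the sqrt
    res = res.clamp_min_(1e-30).sqrt_()
    return res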
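
In case it helps to see what my_cdist() computes, here is a minimal pure-Python sketch of the same expansion, ||x - y||^2 = ||x||^2 - 2 * x.y + ||y||^2, on plain lists. The name pairwise_dist and the sample points are made up for illustration; this is not the torch implementation, only the arithmetic behind it.

```python
import math

def pairwise_dist(xs, ys):
    """Euclidean distance between every row of xs and every row of ys.

    Hypothetical helper that mirrors the arithmetic of my_cdist() on lists.
    """
    # Squared norm of each row, computed once up front
    x_norm = [sum(v * v for v in x) for x in xs]
    y_norm = [sum(v * v for v in y) for y in ys]
    out = []
    for x, xn in zip(xs, x_norm):
        row = []
        for y, yn in zip(ys, y_norm):
            dot = sum(a * b for a, b in zip(x, y))
            sq = xn - 2 * dot + yn          # expanded squared distance
            row.append(math.sqrt(max(sq, 1e-30)))  # same clamp as my_cdist()
        out.append(row)
    return out

xs = [[0.0, 0.0], [3.0, 4.0]]
ys = [[3.0, 4.0], [6.0, 8.0], [0.0, 1.0]]
dists = pairwise_dist(xs, ys)
```

One thing this makes visible: because of the 1e-30 clamp, two identical points come out with a distance of about 1e-15 rather than exactly 0.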

For reference, please also see this AdderNet GitHub issue.

As for torch.cat(), there is no way to replace it yet.