See the runtime error I get when I try to use a different my_cdist() function instead:
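```python
import torch

@torch.jit.script
def my_cdist(x1, x2):
    # Squared L2 norms of each row: x1 -> (N, 1), x2 -> (M, 1)
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
    # ||x2||^2 - 2 * x1 @ x2^T, then add ||x1||^2 in place (broadcast to (N, M))
    res = torch.addmm(x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
    # Clamp round-off negatives/zeros before the square root
    res = res.clamp_min_(1e-30).sqrt_()
    return res
```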
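To sanity-check the function outside of the failing model, here is a minimal sketch that compares it with the built-in torch.cdist on small random inputs. It relies on the expansion ||a - b||^2 = ||a||^2 - 2 a·b + ||b||^2, which is what the addmm call computes. It assumes 2-D inputs of shape (N, D) and (M, D), because torch.addmm only accepts 2-D matrices. The sizes, seed, and float64 dtype are illustrative choices, and this check does not reproduce the runtime error mentioned above.

```python
import torch


@torch.jit.script
def my_cdist(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    # Assumes x1: (N, D) and x2: (M, D); torch.addmm needs 2-D matrices
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)  # (N, 1)
    x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)  # (M, 1)
    # ||x2||^2 broadcast from (1, M), minus 2 * x1 @ x2^T, plus ||x1||^2
    res = torch.addmm(x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
    # Clamp tiny or negative round-off values before the square root
    return res.clamp_min_(1e-30).sqrt_()


torch.manual_seed(0)
a = torch.randn(5, 3, dtype=torch.float64)
b = torch.randn(4, 3, dtype=torch.float64)

ours = my_cdist(a, b)
ref = torch.cdist(a, b)  # built-in pairwise Euclidean distance
print(ours.shape)  # torch.Size([5, 4])
print(torch.allclose(ours, ref))  # True
```

The clamp matters on the diagonal of my_cdist(a, a): round-off can make a squared self-distance slightly negative, and sqrt_() would turn that into NaN without the clamp.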
For reference, please also see this AdderNet GitHub issue.
As for replacing torch.cat(), there is no way to replace that function yet.