Hello, I am having a very strange issue. When doing backwards, the following error code pops up:
Here is a code snippet to reproduce the error.
import torch
def geometric_line_distances(joints):
batch, _, _ = joints.size(0), joints.size(1), joints.size(2)
distances = torch.cdist(joints, joints).to(joints.device)
line_dists2 = torch.zeros(batch, 21, 210).to(joints.device)
for k in range(21):
cnt = 0
for m in range(21):
n = 21 - (m + 1)
if m != k:
size = distances[:, m, m+1:21].size(-1)
s1 = .5*(distances[:, m, m+1:21] +
distances[:, k, m].unsqueeze(-1).repeat(1, size) +
distances[:, m+1:21, k])
x1 = (s1 - distances[:, k, m].unsqueeze(-1).repeat(1, size)) * \
(s1 - distances[:, m, m+1:21]) * \
(s1 - distances[:, m+1:21, k])
x1[x1 < 0] = 1e-10
inter = 2 * torch.sqrt(s1 * x1) / distances[:, m, m+1:21]
line_dists2[:, k, cnt:cnt + inter.size(-1)] = inter
cnt += inter.size(-1)
return line_dists2.contiguous()
""" To reproduce error """
x = torch.randn((4, 21, 3), requires_grad=True)
mse = torch.nn.MSELoss()
dist = geometric_line_distances(x)
loss = mse(dist, .5*dist.ditach())
loss.backward()
Any idea what’s going on?