Strange error when doing loss.backward()

Hello, I am having a very strange issue. When calling loss.backward(), the following error pops up:

Here is a code snippet to reproduce the error.

import torch

def geometric_line_distances(joints):
    batch = joints.size(0)
    distances = torch.cdist(joints, joints)  # already on joints.device
    line_dists2 = torch.zeros(batch, 21, 210, device=joints.device)

    for k in range(21):
        cnt = 0
        for m in range(21):
            if m != k:
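                # size is 0 when m == 20: the slice m+1:21 is then empty,
                # so repeat(1, size) below receives a repeat count of 0
                size = distances[:, m, m+1:21].size(-1)
                # semi-perimeter of the triangles (k, m, n) for n in m+1..20 (Heron's formula)
                s1 = .5*(distances[:, m, m+1:21] +
                         distances[:, k, m].unsqueeze(-1).repeat(1, size) +
                         distances[:, m+1:21, k])
                x1 = (s1 - distances[:, k, m].unsqueeze(-1).repeat(1, size)) * \
                     (s1 - distances[:, m, m+1:21]) * \
                     (s1 - distances[:, m+1:21, k])

                # guard against small negative values from floating-point error before sqrt
                x1[x1 < 0] = 1e-10
                # distance from joint k to the line through joints m and n: 2 * area / base
                inter = 2 * torch.sqrt(s1 * x1) / distances[:, m, m+1:21]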
                line_dists2[:, k, cnt:cnt + inter.size(-1)] = inter
                cnt += inter.size(-1)
    return line_dists2.contiguous()

# To reproduce the error
x = torch.randn((4, 21, 3), requires_grad=True)
mse = torch.nn.MSELoss()
dist = geometric_line_distances(x)
loss = mse(dist, .5*dist.detach())
loss.backward()

Any idea what’s going on?


Hey,

The stack trace points to the backward of the repeat() operation.
What arguments are you passing to it? Is there any special case where some of the inputs have size 0?
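To check this hypothesis in isolation, a minimal sketch like the one below calls repeat() with a repeat count of 0, which is what happens in the question's code when m == 20. It wraps the computation in torch.autograd.detect_anomaly(), which makes autograd report the forward operation whose backward failed. The tensor `a` is made up for illustration. On PyTorch versions affected by this bug the backward() call here is expected to fail; on versions where it has been fixed it should simply produce zero gradients.

```python
import torch

# Hypothetical input standing in for distances[:, k, m] in the question.
a = torch.randn(4, requires_grad=True)

# detect_anomaly() makes autograd point at the forward op whose backward
# raised an error. It slows things down, so use it only while debugging.
with torch.autograd.detect_anomaly():
    rep = a.unsqueeze(-1).repeat(1, 0)  # repeat count 0 -> empty tensor of shape (4, 0)
    rep.sum().backward()

print(rep.shape)  # torch.Size([4, 0])
print(a.grad)     # zeros: `a` contributed nothing to the (empty) sum
```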

Yup, that was it! In some cases, the “size” argument to repeat() becomes zero (when m = 20, the slice m+1:21 is empty). I fixed that and everything runs smoothly.
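The poster does not show their fix, so the following is only a sketch of one way to apply it, assuming the same (batch, 21, 3) input as in the question. It skips the empty slice explicitly and relies on broadcasting instead of repeat(), so repeat() is never called with a zero count. The local names a, b, c, n_joints and n_pairs are introduced here for readability and are not from the original post. The triangle computation (Heron's formula for the distance from joint k to the line through joints m and n) is otherwise unchanged, including the original negative-value guard.

```python
import torch


def geometric_line_distances(joints):
    """Sketch of a workaround for the repeat() issue.

    Assumes `joints` has shape (batch, 21, 3). Empty slices are skipped and
    broadcasting replaces unsqueeze(-1).repeat(1, size).
    """
    batch, n_joints = joints.size(0), joints.size(1)
    distances = torch.cdist(joints, joints)
    n_pairs = n_joints * (n_joints - 1) // 2  # 210 for 21 joints
    line_dists2 = joints.new_zeros(batch, n_joints, n_pairs)

    for k in range(n_joints):
        cnt = 0
        for m in range(n_joints):
            if m == k:
                continue
            a = distances[:, m, m + 1:]           # d(m, n) for n > m, shape (batch, size)
            if a.size(-1) == 0:                   # m is the last joint: nothing to compute
                continue
            b = distances[:, k, m].unsqueeze(-1)  # d(k, m), shape (batch, 1), broadcasts against a
            c = distances[:, m + 1:, k]           # d(n, k), shape (batch, size)
            s1 = 0.5 * (a + b + c)                # semi-perimeter (Heron's formula)
            x1 = (s1 - b) * (s1 - a) * (s1 - c)
            x1[x1 < 0] = 1e-10                    # same guard as the original code
            inter = 2 * torch.sqrt(s1 * x1) / a   # height from k onto the line through m and n
            line_dists2[:, k, cnt:cnt + inter.size(-1)] = inter
            cnt += inter.size(-1)
    return line_dists2


x = torch.randn((4, 21, 3), requires_grad=True)
dist = geometric_line_distances(x)
loss = torch.nn.MSELoss()(dist, .5 * dist.detach())
loss.backward()
```

Either change on its own (skipping the empty slice, or dropping repeat() in favour of broadcasting) should be enough to avoid the zero-count repeat(); broadcasting also avoids materialising the repeated copies.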

Thanks for the answer!

Perfect.
I opened an issue to make sure this gets fixed so that it behaves properly: https://github.com/pytorch/pytorch/issues/45201