Why is loss.backward() so slow (about 20 s)?

Thank you for your suggestions. Through step-by-step debugging, I found that the slow backpropagation was related to the large computational graph. There are two places in the code that greatly extend the computational graph. The first is

    center_x = center_x.unsqueeze(1)
    center_y = center_y.unsqueeze(1)
    center = torch.cat((center_x,center_y), 1).unsqueeze(1).repeat(1,W*H,1)
    scale = scale.unsqueeze(1).repeat(1,H*W).unsqueeze(2).repeat(1,1,2)
    base_grid = base_grid*scale/2
    base_grid = base_grid+center

which can be replaced by

        # Shapes implied by the original code: center_x, center_y, scale are (B,), base_grid is (B, W*H, 2)
        center_x = center_x.unsqueeze(1)
        center_y = center_y.unsqueeze(1)
        # No repeat: keep singleton dimensions so broadcasting works for any batch size
        center = torch.cat((center_x, center_y), 1).unsqueeze(1)  # (B, 1, 2)
        scale = scale.view(-1, 1, 1)                              # (B, 1, 1)
        base_grid = base_grid * scale / 2.
        base_grid = base_grid + center

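The repeat operation copies the data W*H times, so every later operation and its backward step runs on the enlarged tensors. Broadcasting gives the same result without the copies.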
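
As a rough check, the sketch below compares the repeat-based and the broadcasting-based versions on random data. The sizes `B`, `H`, `W`, the `(B,)` shapes of `center_x`, `center_y` and `scale`, and the function names `with_repeat` and `with_broadcast` are assumptions for illustration, not taken from the original model.

```python
import torch

torch.manual_seed(0)
B, H, W = 4, 8, 8  # hypothetical batch size and grid size
N = H * W

base_grid = torch.randn(B, N, 2)
center_x = torch.randn(B, requires_grad=True)
center_y = torch.randn(B, requires_grad=True)
scale = torch.rand(B, requires_grad=True)

def with_repeat(base_grid, center_x, center_y, scale):
    # Original approach: materialize (B, N, 2) copies with repeat.
    center = torch.cat((center_x.unsqueeze(1), center_y.unsqueeze(1)), 1)
    center = center.unsqueeze(1).repeat(1, N, 1)
    s = scale.unsqueeze(1).repeat(1, N).unsqueeze(2).repeat(1, 1, 2)
    return base_grid * s / 2 + center

def with_broadcast(base_grid, center_x, center_y, scale):
    # Keep singleton dimensions and let broadcasting do the expansion.
    center = torch.cat((center_x.unsqueeze(1), center_y.unsqueeze(1)), 1)
    center = center.unsqueeze(1)   # (B, 1, 2)
    s = scale.view(-1, 1, 1)       # (B, 1, 1)
    return base_grid * s / 2 + center

inputs = (center_x, center_y, scale)
out_repeat = with_repeat(base_grid, *inputs)
out_broadcast = with_broadcast(base_grid, *inputs)
same_forward = torch.allclose(out_repeat, out_broadcast)

# Gradients with respect to the small inputs should agree as well.
grads_repeat = torch.autograd.grad(out_repeat.sum(), inputs)
grads_broadcast = torch.autograd.grad(out_broadcast.sum(), inputs)
same_backward = all(
    torch.allclose(a, b, rtol=1e-4, atol=1e-4)
    for a, b in zip(grads_repeat, grads_broadcast)
)
print(same_forward, same_backward)
```

If both flags come out true, the broadcasting version can stand in for the repeat version without changing the forward values or the gradients.
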
The second is

        h = homo.unsqueeze(1).repeat(1, W*H, 1, 1)

        temp1 = (h[:, :, 0, 0] * base_grid[:, :, 0] + h[:, :, 0, 1] * base_grid[:, :, 1] + h[:, :, 0, 2])
        temp2 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
        u1 = temp1 / temp2

        temp3 = (h[:, :, 1, 0] * base_grid[:, :, 0] + h[:, :, 1, 1] * base_grid[:, :, 1] + h[:, :, 1, 2])
        temp4 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])

which can be replaced by

        h = homo.unsqueeze(1)  # (B, 1, 3, 3): no repeat, broadcasting covers the W*H dimension
        temp1 = (h[:, :, 0, 0] * base_grid[:, :, 0] + h[:, :, 0, 1] * base_grid[:, :, 1] + h[:, :, 0, 2])
        temp2 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
        u1 = temp1 / temp2

        temp3 = (h[:, :, 1, 0] * base_grid[:, :, 0] + h[:, :, 1, 1] * base_grid[:, :, 1] + h[:, :, 1, 2])
        temp4 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
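
The same projection can also be written as one batched matrix product on homogeneous coordinates, which is a different formulation from the element-wise code above rather than the original method. The sketch below checks it against the repeat-based version. The names `project_repeat` and `project_matmul`, the sizes, and the second output coordinate `temp3 / temp4` (which the snippet does not show) are assumptions.

```python
import torch

torch.manual_seed(0)
B, N = 4, 64  # hypothetical batch size and number of grid points (W*H)

# Near-identity homographies keep the denominator away from zero in this toy setup.
homo = (torch.eye(3) + 0.01 * torch.randn(B, 3, 3)).requires_grad_()
base_grid = torch.rand(B, N, 2)

def project_repeat(homo, base_grid):
    # Original approach: copy the 3x3 matrix once per grid point.
    h = homo.unsqueeze(1).repeat(1, base_grid.shape[1], 1, 1)  # (B, N, 3, 3)
    x, y = base_grid[:, :, 0], base_grid[:, :, 1]
    denom = h[:, :, 2, 0] * x + h[:, :, 2, 1] * y + h[:, :, 2, 2]
    u1 = (h[:, :, 0, 0] * x + h[:, :, 0, 1] * y + h[:, :, 0, 2]) / denom
    u2 = (h[:, :, 1, 0] * x + h[:, :, 1, 1] * y + h[:, :, 1, 2]) / denom  # assumed temp3 / temp4
    return torch.stack((u1, u2), dim=-1)  # (B, N, 2)

def project_matmul(homo, base_grid):
    ones = torch.ones_like(base_grid[..., :1])
    pts = torch.cat((base_grid, ones), dim=-1)  # homogeneous points, (B, N, 3)
    out = pts @ homo.transpose(1, 2)            # (B, N, 3), one batched matmul
    return out[..., :2] / out[..., 2:]          # divide by the third coordinate

out_a = project_repeat(homo, base_grid)
out_b = project_matmul(homo, base_grid)
same_forward = torch.allclose(out_a, out_b, atol=1e-5)

(grad_a,) = torch.autograd.grad(out_a.sum(), homo)
(grad_b,) = torch.autograd.grad(out_b.sum(), homo)
same_backward = torch.allclose(grad_a, grad_b, rtol=1e-4, atol=1e-4)
print(same_forward, same_backward)
```

The matmul version never builds a `(B, W*H, 3, 3)` tensor, so autograd only has to track a handful of operations on small inputs.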

SkyAndCloud once asked a similar question; see link.

In general, avoid greatly expanding the computational graph during the forward pass; otherwise the backward pass becomes slower.
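
When an explicitly expanded shape is really needed, `expand` may be a cheaper choice than `repeat`, since it returns a view instead of a copy. A minimal sketch of the difference, with hypothetical sizes `B` and `N`:

```python
import torch

B, N = 4, 64  # hypothetical batch size and number of grid points
center = torch.randn(B, 1, 2)

repeated = center.repeat(1, N, 1)  # allocates a new (B, N, 2) tensor
expanded = center.expand(B, N, 2)  # view on the same memory, no copy

same_values = torch.equal(repeated, expanded)
shares_memory = expanded.data_ptr() == center.data_ptr()
was_copied = repeated.data_ptr() != center.data_ptr()
expanded_stride = expanded.stride()  # stride 0 along the expanded dimension
print(same_values, shares_memory, was_copied, expanded_stride)
```

Note that an expanded tensor should not be written to in place, because all positions along the expanded dimension alias the same memory.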