Why is loss.backward() so slow? (taking about 20 s)

Hi everyone. I recently wrote a function to simulate a complex homography transform. First I process the output of the network (ResNet18) and compute the transformed grid with my function. Then I transform a random tensor and compute the loss. However, I find that loss.backward() is very slow.
My code is as follows:

# -*-coding:utf-8-*-
import torch
import torch.nn.functional as F
import pdb
from torch import optim
import sys
sys.path.append('/home/yongjie/uda_for_convex/pose_estimation')
from model import ResNet18
import math
import time
from tqdm import tqdm




def generate_grid(alpha, beta, d):
    size = (1, 3, 720, 720)
    N, C, H, W = size
    B = N
    Rotx = torch.zeros(B, 3, 3).to(device).clone()
    ones = torch.ones(B,).to(device).clone()
    # pdb.set_trace()
    Rotx[:, 0, 0] =  ones
    Rotx[:,1, 1] = torch.cos(beta).squeeze(1)
    Rotx[:,1, 2] = -torch.sin(beta).squeeze(1)
    Rotx[:,2, 1] = torch.sin(beta).squeeze(1)
    Rotx[:,2, 2] = torch.cos(beta).squeeze(1)
    Roty = torch.zeros(B, 3, 3).to(device).clone()
    ones = torch.ones(B,).to(device).clone()
    Roty[:,1,1] = ones.clone()
    Roty[:,0,0] = torch.cos(alpha).squeeze(1)
    Roty[:,0,2] = torch.sin(alpha).squeeze(1)
    Roty[:,2,0] = -torch.sin(alpha).squeeze(1)
    Roty[:,2,2] = torch.cos(alpha).squeeze(1)
    
    # construct homo
    R = torch.bmm(Rotx, Roty)
    R_1 = torch.inverse(R).clone()  
    t = torch.zeros(B,3).to(device)
    # pdb.set_trace()
    t[:,2] = d.squeeze(1) # translation vector
    R_1[:,:,2] = t  
    temp_homo = R_1
    homo = torch.inverse(R_1)


    
    
    # -------------------
    # construct the circle and find the center and scale
    C = torch.zeros(B, 3, 3).to(device)
    C[:,0,0] = torch.tensor(1.)
    C[:,1,1] = torch.tensor(1.)
    C[:,2,2] = torch.tensor(-1.)
    C2 = torch.bmm(torch.inverse(torch.transpose(temp_homo,1,2)), C)
    C2_ = torch.bmm(C2, torch.inverse(temp_homo))
    C3 = torch.inverse(C2_)  # dual format
    
    a = C3[:,0,0]
    b = C3[:,0,2]+C3[:,2,0]
    c = C3[:,2,2]
    right_x = (-b-torch.sqrt(b.mul(b)-4*a.mul(c)))/(2*a)
    left_x = (-b+torch.sqrt(b.mul(b)-4*a.mul(c)))/(2*a)
    right_x = -1./right_x
    left_x = -1./left_x
    width = right_x-left_x
    center_x = (right_x+left_x)/2
    
    a_ = C3[:,1,1]
    b_ = C3[:,1,2]+C3[:,2,1]
    c_ = C3[:,2,2]
    bottom_y = (-b_-torch.sqrt(b_.mul(b_)-4*a_.mul(c_)))/(2*a_)
    top_y = (-b_+torch.sqrt(b_.mul(b_)-4*a_.mul(c_)))/(2*a_)
    bottom_y = -1./bottom_y
    top_y = -1./top_y
    height = bottom_y-top_y
    center_y = (top_y+bottom_y)/2
    scale = torch.max(width, height)
    

    #---------------------
    # generate the compact grid according the homo, center and scale
    # size = (1, 3, 1024, 1024)
    N, C, H, W = size
    N=B
    base_grid = torch.zeros(N, H, W, 2).to(device)
    linear_points = torch.linspace(-1, 1, W).to(device) if W > 1 else torch.Tensor([-1]).to(device)
    base_grid[:, :, :, 0] = torch.ger(torch.ones(H).to(device), linear_points).expand_as(base_grid[:, :, :, 0])
    linear_points = torch.linspace(-1, 1, H).to(device) if H > 1 else torch.Tensor([-1]).to(device)
    base_grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W).to(device)).expand_as(base_grid[:, :, :, 1])
    base_grid = base_grid.view(N, H * W, 2)
    # transform the center and scale
    center_x = center_x.unsqueeze(1)
    center_y = center_y.unsqueeze(1)
    center = torch.cat((center_x,center_y), 1).unsqueeze(1).repeat(1,W*H,1)
    scale = scale.unsqueeze(1).repeat(1,H*W).unsqueeze(2).repeat(1,1,2)
    base_grid = base_grid*scale/2
    base_grid = base_grid+center
    
    # extend the homo, easy to calculate
    h = homo.unsqueeze(1).repeat(1, W*H, 1, 1)
    temp1 = (h[:, :, 0, 0] * base_grid[:, :, 0] + h[:, :, 0, 1] * base_grid[:, :, 1] + h[:, :, 0, 2])
    temp2 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
    u1 = temp1 / temp2
    temp3 = (h[:, :, 1, 0] * base_grid[:, :, 0] + h[:, :, 1, 1] * base_grid[:, :, 1] + h[:, :, 1, 2])
    temp4 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
    v1 = temp3 / temp4
    grid1 = u1.view(N, H, W, 1)
    grid2 = v1.view(N, H, W, 1)
    grid = torch.cat((grid1, grid2), 3)
    return grid



device = 2  # CUDA device index, i.e. cuda:2
BS = 1
predictor = ResNet18(in_channel=3, num_classes=1).to(device)

optimizer = optim.SGD(predictor.parameters(), lr=0.0001, momentum=0.9, weight_decay=0.005)
for i in tqdm(range(100)):
    
    optimizer.zero_grad()
    
    images_source = torch.rand(BS, 3, 720, 720).to(device)
    
    temp_image2 = F.interpolate(images_source, size=(256,256),  mode='bilinear')
    output_four = predictor(temp_image2)
    
    
    k_p = output_four[0]
    alpha_p = output_four[1]
    beta_p = output_four[2]
    d_p = output_four[3]
    
    K_label = (k_p*(-0.22)+(-0.5))
    K_up = (K_label / ((1. * K_label + 1.) ** 2 + 0.0000000001))
    alpha =  ((alpha_p * 120.-60.) * math.pi/180.)
    beta = ((beta_p * 60.-30.) * math.pi/180.)
    d = (d_p * 6.+2.)
    
    
    
    
    
    
    input_tensor = torch.rand(BS, 3, 720, 720).to(device)
    homo_grid = generate_grid(alpha, beta, d)
    h_t = F.grid_sample(input_tensor, homo_grid)
    
    
    temp_loss = F.mse_loss(h_t, torch.tensor([1.]).to(device))
    start = time.time()
    temp_loss.backward()
    print(time.time()-start)
    optimizer.step()
    

The output shows that each call to loss.backward() takes roughly 20 s.

I'm not sure whether this is related to some internal PyTorch function such as torch.inverse or torch.sqrt. Could you give me some advice? Thanks very much.
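Side note on the measurement itself: CUDA kernels are launched asynchronously, so a plain time.time() around loss.backward() can also absorb work queued by the forward pass. Synchronizing before and after the call isolates the backward cost. A minimal sketch of such a measurement, assuming a CUDA device (timed_backward is just an illustrative helper name):

    import time
    import torch

    def timed_backward(loss):
        # Wait for all forward kernels already queued on the GPU to finish,
        # so the timer measures only the backward pass.
        torch.cuda.synchronize()
        start = time.time()
        loss.backward()
        # Wait for the backward kernels to finish before reading the clock.
        torch.cuda.synchronize()
        return time.time() - start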

You could profile the code to isolate the bottleneck of the script, which would help with further debugging. For example, we've been working on using cusolver in more torch.linalg methods, which could speed up your workflow if you are using the nightly binaries or a source build.
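A minimal profiling sketch along these lines, assuming a PyTorch version where torch.profiler is available (1.8.1+; otherwise torch.autograd.profiler.profile(use_cuda=True) works similarly) and reusing the variables from the training loop above (alpha, beta, d, input_tensor, device, F):

    import torch
    from torch.profiler import profile, ProfilerActivity

    # Profile a single forward/backward iteration and list the most expensive ops.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        homo_grid = generate_grid(alpha, beta, d)
        h_t = F.grid_sample(input_tensor, homo_grid)
        temp_loss = F.mse_loss(h_t, torch.tensor([1.]).to(device))
        temp_loss.backward()

    # Sort by CUDA time to see whether e.g. the inverse/sqrt kernels or the
    # backward of the grid construction dominate.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))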

Thank you for your suggestions. Through step-by-step debugging, I found that the slow backward pass was caused by the large computational graph. There are two places in the code that greatly enlarge the graph. The first is

    center_x = center_x.unsqueeze(1)
    center_y = center_y.unsqueeze(1)
    center = torch.cat((center_x,center_y), 1).unsqueeze(1).repeat(1,W*H,1)
    scale = scale.unsqueeze(1).repeat(1,H*W).unsqueeze(2).repeat(1,1,2)
    base_grid = base_grid*scale/2
    base_grid = base_grid+center

which can be replaced by

    center_x = center_x.unsqueeze(1)
    center_y = center_y.unsqueeze(1)
    # center = torch.cat((center_x,center_y), 1).unsqueeze(1).repeat(1,W*H,1)
    # scale = scale.unsqueeze(1).repeat(1,H*W).unsqueeze(2).repeat(1,1,2)
    center = torch.cat((center_x,center_y), 1)
    scale = scale
    base_grid = base_grid*scale/2.
    base_grid = base_grid+center

The repeat operation greatly expands the computational graph.
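A small standalone toy example of this effect (hypothetical sizes, not the original model): it runs the same elementwise computation once with the repeated tensors and once relying on broadcasting, timing only the backward pass.

    import time
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    B, H, W = 1, 720, 720
    base_grid = torch.rand(B, H * W, 2, device=device)
    center = torch.rand(B, 2, device=device, requires_grad=True)
    scale = torch.rand(B, device=device, requires_grad=True)

    def backward_time(use_repeat):
        if use_repeat:
            # Materialize H*W copies, as in the original code.
            c = center.unsqueeze(1).repeat(1, H * W, 1)                        # (B, H*W, 2)
            s = scale.unsqueeze(1).repeat(1, H * W).unsqueeze(2).repeat(1, 1, 2)
        else:
            # Let broadcasting expand (B, 2) and (B,) against (B, H*W, 2); works for B == 1.
            c, s = center, scale
        loss = (base_grid * s / 2 + c).sum()
        if device.type == "cuda":
            torch.cuda.synchronize()
        start = time.time()
        loss.backward()
        if device.type == "cuda":
            torch.cuda.synchronize()
        return time.time() - start

    print("backward with repeat:   ", backward_time(True))
    print("backward with broadcast:", backward_time(False))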
The other is

    h = homo.unsqueeze(1).repeat(1, W*H, 1, 1)

    temp1 = (h[:, :, 0, 0] * base_grid[:, :, 0] + h[:, :, 0, 1] * base_grid[:, :, 1] + h[:, :, 0, 2])
    temp2 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
    u1 = temp1 / temp2

    temp3 = (h[:, :, 1, 0] * base_grid[:, :, 0] + h[:, :, 1, 1] * base_grid[:, :, 1] + h[:, :, 1, 2])
    temp4 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])

which can be replaced by

    h = homo
    temp1 = (h[:, 0, 0] * base_grid[:, :, 0] + h[:, 0, 1] * base_grid[:, :, 1] + h[:, 0, 2])
    temp2 = (h[:, 2, 0] * base_grid[:, :, 0] + h[:, 2, 1] * base_grid[:, :, 1] + h[:, 2, 2])
    u1 = temp1 / temp2

    temp3 = (h[:, 1, 0] * base_grid[:, :, 0] + h[:, 1, 1] * base_grid[:, :, 1] + h[:, 1, 2])
    temp4 = (h[:, 2, 0] * base_grid[:, :, 0] + h[:, 2, 1] * base_grid[:, :, 1] + h[:, 2, 2])

SkyAndCloud once asked a similar question, which can be seen in this link.

In general, it is better not to greatly expand the computational graph during the forward pass; otherwise the backward pass will be slower.
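For reference, the same broadcasting idea can be written so that it does not depend on B = 1, by adding explicit singleton dimensions instead of repeating. A possible variant of the two replacements above (a sketch based on the shapes in the snippets, not tested against the original model):

    # Assumed shapes: center_x, center_y, scale: (B,);  base_grid: (B, H*W, 2);  homo: (B, 3, 3)
    center = torch.stack((center_x, center_y), dim=1).unsqueeze(1)   # (B, 1, 2)
    scale = scale.unsqueeze(1).unsqueeze(2)                          # (B, 1, 1)
    base_grid = base_grid * scale / 2.
    base_grid = base_grid + center

    # Index the homography once and keep a singleton point dimension so it
    # broadcasts against the H*W grid points for any batch size.
    h = homo.unsqueeze(1)                                            # (B, 1, 3, 3)
    temp1 = h[:, :, 0, 0] * base_grid[:, :, 0] + h[:, :, 0, 1] * base_grid[:, :, 1] + h[:, :, 0, 2]
    temp2 = h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2]
    u1 = temp1 / temp2
    temp3 = h[:, :, 1, 0] * base_grid[:, :, 0] + h[:, :, 1, 1] * base_grid[:, :, 1] + h[:, :, 1, 2]
    temp4 = h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2]
    v1 = temp3 / temp4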