Why loss.backward() is so slow (taking about 20s)

Hi everyone. Recently I write a function to simulate a complex homography transform. I firstly deal with the output of the network (resnet18) and get the transformed grid using my written function. Then I transform the random tensor and compute the loss. However, I find the loss.backward() is very slow.
My code is as following:

``````# -*-coding:utf-8-*-
import torch
import torch.nn.functional as F
import pdb
from torch import optim
import sys
import sys
sys.path.append('/home/yongjie/uda_for_convex/pose_estimation')
from model import ResNet18
import math
import time
from tqdm import tqdm

def generate_grid(alpha, beta, d):
size = (1, 3, 720, 720)
N, C, H, W = size
B = N
Rotx = torch.zeros(B, 3, 3).to(device).clone()
ones = torch.ones(B,).to(device).clone()
# pdb.set_trace()
Rotx[:, 0, 0] =  ones
Rotx[:,1, 1] = torch.cos(beta).squeeze(1)
Rotx[:,1, 2] = -torch.sin(beta).squeeze(1)
Rotx[:,2, 1] = torch.sin(beta).squeeze(1)
Rotx[:,2, 2] = torch.cos(beta).squeeze(1)
Roty = torch.zeros(B, 3, 3).to(device).clone()
ones = torch.ones(B,).to(device).clone()
Roty[:,1,1] = ones.clone()
Roty[:,0,0] = torch.cos(alpha).squeeze(1)
Roty[:,0,2] = torch.sin(alpha).squeeze(1)
Roty[:,2,0] = -torch.sin(alpha).squeeze(1)
Roty[:,2,2] = torch.cos(alpha).squeeze(1)

# construct homo
R = torch.bmm(Rotx, Roty)
R_1 = torch.inverse(R).clone()
t = torch.zeros(B,3).to(device)
# pdb.set_trace()
t[:,2] = d.squeeze(1) # translation vector
R_1[:,:,2] = t
temp_homo = R_1
homo = torch.inverse(R_1)

# -------------------
# construct the circle and find the center and scale
C = torch.zeros(B, 3, 3).to(device)
C[:,0,0] = torch.tensor(1.)
C[:,1,1] = torch.tensor(1.)
C[:,2,2] = torch.tensor(-1.)
C2 = torch.bmm(torch.inverse(torch.transpose(temp_homo,1,2)), C)
C2_ = torch.bmm(C2, torch.inverse(temp_homo))
C3 = torch.inverse(C2_)  # dual format

a = C3[:,0,0]
b = C3[:,0,2]+C3[:,2,0]
c = C3[:,2,2]
right_x = (-b-torch.sqrt(b.mul(b)-4*a.mul(c)))/(2*a)
left_x = (-b+torch.sqrt(b.mul(b)-4*a.mul(c)))/(2*a)
right_x = -1./right_x
left_x = -1./left_x
width = right_x-left_x
center_x = (right_x+left_x)/2

a_ = C3[:,1,1]
b_ = C3[:,1,2]+C3[:,2,1]
c_ = C3[:,2,2]
bottom_y = (-b_-torch.sqrt(b_.mul(b_)-4*a_.mul(c_)))/(2*a_)
top_y = (-b_+torch.sqrt(b_.mul(b_)-4*a_.mul(c_)))/(2*a_)
bottom_y = -1./bottom_y
top_y = -1./top_y
height = bottom_y-top_y
center_y = (top_y+bottom_y)/2
scale = torch.max(width, height)

#---------------------
# generate the compact grid according the homo, center and scale
# size = (1, 3, 1024, 1024)
N, C, H, W = size
N=B
base_grid = torch.zeros(N, H, W, 2).to(device)
linear_points = torch.linspace(-1, 1, W).to(device) if W > 1 else torch.Tensor([-1]).to(device)
base_grid[:, :, :, 0] = torch.ger(torch.ones(H).to(device), linear_points).expand_as(base_grid[:, :, :, 0])
linear_points = torch.linspace(-1, 1, H).to(device) if H > 1 else torch.Tensor([-1]).to(device)
base_grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W).to(device)).expand_as(base_grid[:, :, :, 1])
base_grid = base_grid.view(N, H * W, 2)
# transform the center and scale
center_x = center_x.unsqueeze(1)
center_y = center_y.unsqueeze(1)
center = torch.cat((center_x,center_y), 1).unsqueeze(1).repeat(1,W*H,1)
scale = scale.unsqueeze(1).repeat(1,H*W).unsqueeze(2).repeat(1,1,2)
base_grid = base_grid*scale/2
base_grid = base_grid+center

# extend the homo, easy to calculate
h = homo.unsqueeze(1).repeat(1, W*H, 1, 1)
temp1 = (h[:, :, 0, 0] * base_grid[:, :, 0] + h[:, :, 0, 1] * base_grid[:, :, 1] + h[:, :, 0, 2])
temp2 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
u1 = temp1 / temp2
temp3 = (h[:, :, 1, 0] * base_grid[:, :, 0] + h[:, :, 1, 1] * base_grid[:, :, 1] + h[:, :, 1, 2])
temp4 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
v1 = temp3 / temp4
grid1 = u1.view(N, H, W, 1)
grid2 = v1.view(N, H, W, 1)
grid = torch.cat((grid1, grid2), 3)
return grid

device = 2
BS = 1
predictor = ResNet18(in_channel=3, num_classes=1).to(device)

optimizer = optim.SGD(predictor.parameters(), lr=0.0001, momentum=0.9, weight_decay=0.005)
for i in tqdm(range(100)):

images_source = torch.rand(BS, 3, 720, 720).to(device)

temp_image2 = F.interpolate(images_source, size=(256,256),  mode='bilinear')
output_four = predictor(temp_image2)

k_p = output_four[0]
alpha_p = output_four[1]
beta_p = output_four[2]
d_p = output_four[3]

K_label = (k_p*(-0.22)+(-0.5))
K_up = (K_label / ((1. * K_label + 1.) ** 2 + 0.0000000001))
alpha =  ((alpha_p * 120.-60.) * math.pi/180.)
beta = ((beta_p * 60.-30.) * math.pi/180.)
d = (d_p * 6.+2.)

input_tensor = torch.rand(BS, 3, 720, 720).to(device)
homo_grid = generate_grid(alpha, beta, d)
h_t = F.grid_sample(input_tensor, homo_grid)

temp_loss = F.mse_loss(h_t, torch.tensor([1.]).to(device))
start = time.time()
temp_loss.backward()
print(time.time()-start)
optimizer.step()

``````

The output is

I’m not sure if this phenomenon is related to some internal function in pytorch such as torch.inverse or torch.sqrt. Could you give me some advice? Thanks very much.

You could profile the code to further isolate the bottleneck of the script, which could help in further debugging. E.g. we’ve been working on the usage of cusolver in more `torch.linalg` methods, which could speed up your workflow in case you are using the nightly binaries or a source build.

Thank you for your suggestions. Through step-by-step debugging, I found that the slow backpropagation was related to the large computational graph. There are two places in the code that greatly extend the computational graph. The first is

``````    center_x = center_x.unsqueeze(1)
center_y = center_y.unsqueeze(1)
center = torch.cat((center_x,center_y), 1).unsqueeze(1).repeat(1,W*H,1)
scale = scale.unsqueeze(1).repeat(1,H*W).unsqueeze(2).repeat(1,1,2)
base_grid = base_grid*scale/2
base_grid = base_grid+center
``````

which can be replaced by

``````        center_x = center_x.unsqueeze(1)
center_y = center_y.unsqueeze(1)
# center = torch.cat((center_x,center_y), 1).unsqueeze(1).repeat(1,W*H,1)
# scale = scale.unsqueeze(1).repeat(1,H*W).unsqueeze(2).repeat(1,1,2)
center = torch.cat((center_x,center_y), 1)
scale = scale
base_grid = base_grid*scale/2.
base_grid = base_grid+center
``````

The repeat operation will greatly expand the computational graph.
Other is

``````        h = homo.unsqueeze(1).repeat(1, W*H, 1, 1)

temp1 = (h[:, :, 0, 0] * base_grid[:, :, 0] + h[:, :, 0, 1] * base_grid[:, :, 1] + h[:, :, 0, 2])
temp2 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
u1 = temp1 / temp2

temp3 = (h[:, :, 1, 0] * base_grid[:, :, 0] + h[:, :, 1, 1] * base_grid[:, :, 1] + h[:, :, 1, 2])
temp4 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
``````

which can be replaced by

``````        h = homo
temp1 = (h[:, 0, 0] * base_grid[:, :, 0] + h[:, 0, 1] * base_grid[:, :, 1] + h[:, 0, 2])
temp2 = (h[:, 2, 0] * base_grid[:, :, 0] + h[:, 2, 1] * base_grid[:, :, 1] + h[:, 2, 2])
u1 = temp1 / temp2

temp3 = (h[:, 1, 0] * base_grid[:, :, 0] + h[:, 1, 1] * base_grid[:, :, 1] + h[:, 1, 2])
temp4 = (h[:, 2, 0] * base_grid[:, :, 0] + h[:, 2, 1] * base_grid[:, :, 1] + h[:, 2, 2])
``````

SkyAndCloud once asked a similar question which can be seen in link.

In general, it is better not to greatly expand the computational graph during the forward pass, otherwise, it will cause the backward to be slower.