Hi everyone. Recently I write a function to simulate a complex homography transform. I firstly deal with the output of the network (resnet18) and get the transformed grid using my written function. Then I transform the random tensor and compute the loss. However, I find the loss.backward() is very slow.
My code is as following:
# -*-coding:utf-8-*-
import torch
import torch.nn.functional as F
import pdb
from torch import optim
import sys
import sys
sys.path.append('/home/yongjie/uda_for_convex/pose_estimation')
from model import ResNet18
import math
import time
from tqdm import tqdm
def generate_grid(alpha, beta, d):
size = (1, 3, 720, 720)
N, C, H, W = size
B = N
Rotx = torch.zeros(B, 3, 3).to(device).clone()
ones = torch.ones(B,).to(device).clone()
# pdb.set_trace()
Rotx[:, 0, 0] = ones
Rotx[:,1, 1] = torch.cos(beta).squeeze(1)
Rotx[:,1, 2] = -torch.sin(beta).squeeze(1)
Rotx[:,2, 1] = torch.sin(beta).squeeze(1)
Rotx[:,2, 2] = torch.cos(beta).squeeze(1)
Roty = torch.zeros(B, 3, 3).to(device).clone()
ones = torch.ones(B,).to(device).clone()
Roty[:,1,1] = ones.clone()
Roty[:,0,0] = torch.cos(alpha).squeeze(1)
Roty[:,0,2] = torch.sin(alpha).squeeze(1)
Roty[:,2,0] = -torch.sin(alpha).squeeze(1)
Roty[:,2,2] = torch.cos(alpha).squeeze(1)
# construct homo
R = torch.bmm(Rotx, Roty)
R_1 = torch.inverse(R).clone()
t = torch.zeros(B,3).to(device)
# pdb.set_trace()
t[:,2] = d.squeeze(1) # translation vector
R_1[:,:,2] = t
temp_homo = R_1
homo = torch.inverse(R_1)
# -------------------
# construct the circle and find the center and scale
C = torch.zeros(B, 3, 3).to(device)
C[:,0,0] = torch.tensor(1.)
C[:,1,1] = torch.tensor(1.)
C[:,2,2] = torch.tensor(-1.)
C2 = torch.bmm(torch.inverse(torch.transpose(temp_homo,1,2)), C)
C2_ = torch.bmm(C2, torch.inverse(temp_homo))
C3 = torch.inverse(C2_) # dual format
a = C3[:,0,0]
b = C3[:,0,2]+C3[:,2,0]
c = C3[:,2,2]
right_x = (-b-torch.sqrt(b.mul(b)-4*a.mul(c)))/(2*a)
left_x = (-b+torch.sqrt(b.mul(b)-4*a.mul(c)))/(2*a)
right_x = -1./right_x
left_x = -1./left_x
width = right_x-left_x
center_x = (right_x+left_x)/2
a_ = C3[:,1,1]
b_ = C3[:,1,2]+C3[:,2,1]
c_ = C3[:,2,2]
bottom_y = (-b_-torch.sqrt(b_.mul(b_)-4*a_.mul(c_)))/(2*a_)
top_y = (-b_+torch.sqrt(b_.mul(b_)-4*a_.mul(c_)))/(2*a_)
bottom_y = -1./bottom_y
top_y = -1./top_y
height = bottom_y-top_y
center_y = (top_y+bottom_y)/2
scale = torch.max(width, height)
#---------------------
# generate the compact grid according the homo, center and scale
# size = (1, 3, 1024, 1024)
N, C, H, W = size
N=B
base_grid = torch.zeros(N, H, W, 2).to(device)
linear_points = torch.linspace(-1, 1, W).to(device) if W > 1 else torch.Tensor([-1]).to(device)
base_grid[:, :, :, 0] = torch.ger(torch.ones(H).to(device), linear_points).expand_as(base_grid[:, :, :, 0])
linear_points = torch.linspace(-1, 1, H).to(device) if H > 1 else torch.Tensor([-1]).to(device)
base_grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W).to(device)).expand_as(base_grid[:, :, :, 1])
base_grid = base_grid.view(N, H * W, 2)
# transform the center and scale
center_x = center_x.unsqueeze(1)
center_y = center_y.unsqueeze(1)
center = torch.cat((center_x,center_y), 1).unsqueeze(1).repeat(1,W*H,1)
scale = scale.unsqueeze(1).repeat(1,H*W).unsqueeze(2).repeat(1,1,2)
base_grid = base_grid*scale/2
base_grid = base_grid+center
# extend the homo, easy to calculate
h = homo.unsqueeze(1).repeat(1, W*H, 1, 1)
temp1 = (h[:, :, 0, 0] * base_grid[:, :, 0] + h[:, :, 0, 1] * base_grid[:, :, 1] + h[:, :, 0, 2])
temp2 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
u1 = temp1 / temp2
temp3 = (h[:, :, 1, 0] * base_grid[:, :, 0] + h[:, :, 1, 1] * base_grid[:, :, 1] + h[:, :, 1, 2])
temp4 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
v1 = temp3 / temp4
grid1 = u1.view(N, H, W, 1)
grid2 = v1.view(N, H, W, 1)
grid = torch.cat((grid1, grid2), 3)
return grid
device = 2
BS = 1
predictor = ResNet18(in_channel=3, num_classes=1).to(device)
optimizer = optim.SGD(predictor.parameters(), lr=0.0001, momentum=0.9, weight_decay=0.005)
for i in tqdm(range(100)):
optimizer.zero_grad()
images_source = torch.rand(BS, 3, 720, 720).to(device)
temp_image2 = F.interpolate(images_source, size=(256,256), mode='bilinear')
output_four = predictor(temp_image2)
k_p = output_four[0]
alpha_p = output_four[1]
beta_p = output_four[2]
d_p = output_four[3]
K_label = (k_p*(-0.22)+(-0.5))
K_up = (K_label / ((1. * K_label + 1.) ** 2 + 0.0000000001))
alpha = ((alpha_p * 120.-60.) * math.pi/180.)
beta = ((beta_p * 60.-30.) * math.pi/180.)
d = (d_p * 6.+2.)
input_tensor = torch.rand(BS, 3, 720, 720).to(device)
homo_grid = generate_grid(alpha, beta, d)
h_t = F.grid_sample(input_tensor, homo_grid)
temp_loss = F.mse_loss(h_t, torch.tensor([1.]).to(device))
start = time.time()
temp_loss.backward()
print(time.time()-start)
optimizer.step()
The output is
I’m not sure if this phenomenon is related to some internal function in pytorch such as torch.inverse or torch.sqrt. Could you give me some advice? Thanks very much.