PyTorch question: loss.backward() takes 8 seconds!

Hello everyone,

I’m working on a deep learning model that combines three loss functions. Two of them backpropagate very slowly, which significantly increases the training time, and I suspect the issue is related to the way these losses are calculated.

   
```python
for ep in range(args.epochs):
    model.train()

    t1 = t_1 = default_timer()
    t_load, t_train = 0., 0.
    train_l2_step = 0
    train_l2_full = 0

    loss_previous = np.inf
    torch.autograd.set_detect_anomaly(True)  # anomaly detection adds overhead to every backward pass
    
    for xx, yy, msk, cls in train_loader:
        t_load += default_timer() - t_1
        t_1 = default_timer()
        model.to(device)
        loss, physics_loss, coefficient_loss = 0., 0., 0.
        torch.cuda.empty_cache()
        for t in range(0, yy.shape[-2], args.T_bundle):
            torch.cuda.empty_cache()
            y = yy[..., t:t + args.T_bundle, :]
            ### auto-regressive training
            xx = xx + args.noise_scale * torch.sum(xx**2, dim=(1,2,3), keepdim=True)**0.5 * torch.randn_like(xx)
            im, _ = model(xx)
            loss += myloss(im, y, mask=msk)

            if t == 0:
                pred = im
            else:
                pred = torch.cat((pred, im), dim=-2)

            batch_size, x_size, y_size, t_size, channels = xx.shape
            U_all, U_t = U_all_compute(xx,
                                       device=device,
                                       upper_bound_x=upper_bound_x,
                                       upper_bound_y=upper_bound_y,
                                       upper_bound_t=upper_bound_t,
                                       lower_bound_t=lower_bound_t,
                                       lower_bound_x=lower_bound_x,
                                       lower_bound_y=lower_bound_y)
            U = xx.unsqueeze(0)
            lib_poly = lib_poly_compute(P=3, 
                                        U_all = U_all,
                                        U = U)

            lam, d_tol, maxit, STR_iters, normalize, l0_penalty, print_best_tol = 1e-5, 1, 50, 10, 2, 1e-4, False
            lambda_w = torch.randn([lib_poly.shape[0], 1]).to(device)  # e.g. a [16, 1] tensor
            
            start_time = time.time()
            w_best, phi_loss = Train_STRidge(lib_poly=lib_poly,
                                U_t=U_t,
                                device=device,
                                lam=lam,
                                maxit=maxit,
                                normalize=normalize,
                                lambda_w=lambda_w,
                                l0_penalty=l0_penalty,
                                print_best_tol=False,
                                d_tol = d_tol)
            phi_loss_copy = phi_loss.clone()    #.detach()
            physics_loss = physics_loss + phi_loss_copy
            if t==0:
                A_real = w_best
                pred_w = A_real
            else:
                A_real_copy = A_real.clone()    #.detach()
                w_best_copy = w_best.clone()    #.detach()
                coefficient_loss = coefficient_loss + torch.mean((A_real_copy - w_best_copy)**2) 
                pred_w = torch.cat( (pred_w , w_best) , dim = -1)
            
            xx = torch.cat((xx[..., args.T_bundle:, :], im), dim=-2)

        optimizer.zero_grad() 
        total_loss = loss + coefficient_loss + physics_loss 
        total_loss.backward()
```
And the functions used are defined here:

```python
def Train_STRidge(lib_poly = torch.randn([16, 20, 128, 128, 10, 2], device="cuda:0"),
                  U_t = torch.randn([1, 20, 128, 128, 10, 2], device="cuda:0"),
                  device = "cuda:0",
                  lam = 1e-5,
                  maxit = 100,
                  normalize = 2,
                  lambda_w = torch.randn([16, 1]).to("cuda:0"),
                  l0_penalty = 1e-4,
                  print_best_tol = False,
                  d_tol = 1
                  ):
    # flatten the library into a 2-D design matrix of shape [n_points, n_terms]
    flattened_dim = torch.prod(torch.tensor(lib_poly.shape[1:])).item()
    lib_poly = lib_poly.view(lib_poly.shape[0], flattened_dim).transpose(0, 1)
    U_t = U_t.reshape(-1, 1)

    w_best = STRidge(
        X0=lib_poly,
        y=U_t,
        lam=lam,
        maxit=maxit,
        tol=d_tol,
        normalize=normalize,
        device=device,
        lambda_w=lambda_w
    )
    # myloss = SimpleLpLoss(size_average=False)
    # residual plus a penalty proportional to the number of nonzero coefficients
    err_f = torch.mean((U_t - lib_poly @ w_best.to(device)) ** 2)
    err_best = err_f.item() * torch.count_nonzero(w_best) + err_f
    return w_best, err_best
```

```python
def U_all_compute(yy = torch.randn([20, 128, 128, 10, 3]),
                  device = "cuda:0",
                  upper_bound_x = 2.5,
                  upper_bound_y = 2.5,
                  upper_bound_t = 1,
                  lower_bound_t = 0,
                  lower_bound_x = -2.5,
                  lower_bound_y = -2.5
                  ):
    _, size_x, size_y, size_t = yy.shape[0], yy.shape[1], yy.shape[2], yy.shape[3]
    gridx = torch.tensor(np.linspace(lower_bound_x, upper_bound_x, size_x), dtype=torch.float, device=device)
    gridy = torch.tensor(np.linspace(lower_bound_y, upper_bound_y, size_y), dtype=torch.float, device=device)
    gridt = torch.tensor(np.linspace(lower_bound_t, upper_bound_t, size_t), dtype=torch.float, device=device)

    dx = gridx[1] - gridx[0]
    dy = gridy[1] - gridy[0]
    dt = gridt[1] - gridt[0]
    # U = yy.unsqueeze(0)
    U_t = []  # torch.zeros_like(yy)
    for i in range(yy.shape[-1]):
        yy_channel = yy[..., i:i+1]      # resulting shape [batch, x, y, t]
        du_dx = torch.gradient(yy_channel, spacing=dx, dim=1)   # gradient along x: torch.Size([20, 128, 128, 10])
        du_dy = torch.gradient(yy_channel, spacing=dy, dim=2)   # gradient along y
        du_dt = torch.gradient(yy_channel, spacing=dt, dim=3)
        du_dx_dx = torch.gradient(du_dx[0], spacing=dx, dim=1)
        du_dy_dx = torch.gradient(du_dy[0], spacing=dx, dim=1)
        du_dy_dy = torch.gradient(du_dy[0], spacing=dy, dim=2)
        du_dx_dy = torch.gradient(du_dx[0], spacing=dy, dim=2)

        # print("du_dx .shape:", du_dx[0].shape)  # [20, 128, 128, 10, 1]
        # print("du_dy .shape:", du_dy[0].shape)  # [20, 128, 128, 10, 1]
        # print("du_dt .shape:", du_dt[0].shape)
        if i == 0:
            U_t, U_x, U_y, U_xy, U_xx, U_yy, U_yx = du_dt[0], du_dx[0], du_dy[0], du_dx_dy[0], du_dx_dx[0], du_dy_dy[0], du_dy_dx[0]
        else:
            U_t = torch.cat((U_t, du_dt[0]), dim=-1)
            U_x = torch.cat((U_x, du_dx[0]), dim=-1)
            U_y = torch.cat((U_y, du_dy[0]), dim=-1)
            U_xy = torch.cat((U_xy, du_dx_dy[0]), dim=-1)
            U_xx = torch.cat((U_xx, du_dx_dx[0]), dim=-1)
            U_yy = torch.cat((U_yy, du_dy_dy[0]), dim=-1)
            U_yx = torch.cat((U_yx, du_dy_dx[0]), dim=-1)

    U_all = torch.stack([yy, U_x, U_y, U_xy, U_xx, U_yy, U_yx], dim=0)

    return U_all, U_t
```

```python
def lib_poly_compute(P = 3,
                     U_all = torch.randn([10, 20, 128, 128, 10, 2], device="cuda:0"),
                     U = torch.randn([1, 20, 128, 128, 10, 2], device="cuda:0")
                     ):
    for i in range(U.shape[0]):
        for j in range(1, P+1):
            # print(f"U[i,...]**j.shape = {(U[i,...]**j).shape}")
            if i == 0 and j == 1:
                lib_poly = U[i:i+1, ...]**j
            else:
                lib_poly = torch.cat((lib_poly, U[i:i+1, ...]**j), dim=0)

    lib_poly = torch.cat((lib_poly, torch.ones_like(U[0:1, ...], device=U_all.device)), dim=0)
    U_poly = lib_poly

    for i in range(U_all.shape[0]):
        for j in range(U_poly.shape[0]):
            lib_poly = torch.cat((lib_poly, U_all[i:i+1, ...] * U_poly[j:j+1, ...]), dim=0)

    lib_poly = torch.cat((lib_poly, U_all), dim=0)
    return lib_poly
```

I suspect the issue is related to the complexity of the tensor operations involved in computing these losses, but I’m not sure how to optimize them.

Has anyone faced similar issues with slow backpropagation due to complex tensor operations? Any help or suggestions would be greatly appreciated!

Thank you!
#pytorch

PyTorch executes CUDA kernels asynchronously. If you are using host timers to profile your code, you would need to synchronize before starting and stopping the timers to measure the actual kernel execution time.
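
Something like this, for example (a minimal, self-contained sketch: the dummy matmul stands in for your model, and in your training loop the same pattern would wrap `total_loss.backward()`):

```python
import time
import torch

# dummy workload standing in for the model's forward pass and loss
x = torch.randn(4096, 4096, device="cuda", requires_grad=True)
loss = (x @ x).sum()

torch.cuda.synchronize()   # wait for all previously queued kernels to finish
t0 = time.time()
loss.backward()
torch.cuda.synchronize()   # wait for the backward kernels to actually finish
print(f"backward time: {time.time() - t0:.4f} s")
```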

Thank you for your response! Do you mean that I should add torch.cuda.synchronize() before measuring the time?

Yes, either synchronize manually or use torch.utils.benchmark, which synchronizes for you and adds warmup iterations. This post shows an example.
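
For reference, a minimal sketch using torch.utils.benchmark (the `forward_backward` function here is a placeholder for your own forward pass and loss; the Timer takes care of CUDA synchronization and warmup):

```python
import torch
import torch.utils.benchmark as benchmark

x = torch.randn(1024, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda", requires_grad=True)

def forward_backward():
    # placeholder forward pass + loss + backward
    loss = (x @ w).square().mean()
    loss.backward()

timer = benchmark.Timer(
    stmt="forward_backward()",
    globals={"forward_backward": forward_backward},
)
# blocked_autorange picks the number of runs automatically and
# synchronizes CUDA around the measurements
print(timer.blocked_autorange())
```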
