PyTorch: loss.backward() takes 8 seconds!

Hello everyone,

I’m working on a deep learning model that uses three loss functions. I’ve noticed that two of them make backpropagation very slow, which significantly increases the training time.
I suspect the issue is related to the way these loss functions are computed.
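
For reference, a minimal way to time just the backward call looks like this (my sketch; the torch.cuda.synchronize() calls are needed because CUDA kernels run asynchronously, and total_loss is the combined loss from the training loop below):

import time
import torch

torch.cuda.synchronize()   # make sure all queued kernels have finished
t0 = time.time()
total_loss.backward()      # this is the call that takes ~8 seconds
torch.cuda.synchronize()   # wait for the backward kernels before stopping the timer
print(f"backward took {time.time() - t0:.2f} s")

The training loop itself: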

   
for ep in range(args.epochs):
    model.train()

    t1 = t_1 = default_timer()
    t_load, t_train = 0., 0.
    train_l2_step = 0
    train_l2_full = 0

    loss_previous = np.inf
    torch.autograd.set_detect_anomaly(True) 
    
    for xx, yy, msk, cls in train_loader:
        t_load += default_timer() - t_1
        t_1 = default_timer()
        model.to(device)
        loss, physics_loss, coefficient_loss = 0., 0., 0.
        torch.cuda.empty_cache()
        for t in range(0, yy.shape[-2], args.T_bundle):
            torch.cuda.empty_cache()
            y = yy[..., t:t + args.T_bundle, :]
            ### auto-regressive training
            xx = xx + args.noise_scale *torch.sum(xx**2, dim=(1,2,3), keepdim=True)**0.5 * torch.randn_like(xx)
            im, _ = model(xx)
            loss += myloss(im, y, mask=msk)

            if t == 0:
                pred = im
            else:
                pred = torch.cat((pred, im), dim=-2)

            batch_size, x_size, y_size, t_size, channels = xx.shape
            U_all, U_t = U_all_compute(xx,
                                       device=device,
                                       upper_bound_x=upper_bound_x,
                                       upper_bound_y=upper_bound_y,
                                       upper_bound_t=upper_bound_t,
                                       lower_bound_t=lower_bound_t,
                                       lower_bound_x=lower_bound_x,
                                       lower_bound_y=lower_bound_y)
            U = xx.unsqueeze(0)
            lib_poly = lib_poly_compute(P=3, 
                                        U_all = U_all,
                                        U = U)

            lam, d_tol, maxit, STR_iters, normalize, l0_penalty, print_best_tol = 1e-5, 1, 50, 10, 2, 1e-4, False
            lambda_w = torch.randn([lib_poly.shape[0], 1]).to(device)  # e.g. shape [16, 1]
            
            start_time = time.time()
            w_best, phi_loss = Train_STRidge(lib_poly=lib_poly,
                                U_t=U_t,
                                device=device,
                                lam=lam,
                                maxit=maxit,
                                normalize=normalize,
                                lambda_w=lambda_w,
                                l0_penalty=l0_penalty,
                                print_best_tol=False,
                                d_tol = d_tol)
            phi_loss_copy = phi_loss.clone()    #.detach()
            physics_loss = physics_loss + phi_loss_copy
            if t==0:
                A_real = w_best
                pred_w = A_real
            else:
                A_real_copy = A_real.clone()    #.detach()
                w_best_copy = w_best.clone()    #.detach()
                coefficient_loss = coefficient_loss + torch.mean((A_real_copy - w_best_copy)**2) 
                pred_w = torch.cat( (pred_w , w_best) , dim = -1)
            
            xx = torch.cat((xx[..., args.T_bundle:, :], im), dim=-2)

        optimizer.zero_grad() 
        total_loss = loss + coefficient_loss + physics_loss 
        total_loss.backward()

The functions used above are defined here:

def Train_STRidge(lib_poly = torch.randn([16, 20, 128, 128, 10, 2], device="cuda:0"),
                  U_t = torch.randn([1, 20, 128, 128, 10, 2], device="cuda:0"),
                  device = "cuda:0",
                  lam = 1e-5,
                  maxit = 100,
                  normalize = 2,
                  lambda_w = torch.randn([16, 1]).to("cuda:0"),
                  l0_penalty = 1e-4,
                  print_best_tol = False,
                  d_tol = 1
                  ):
    flattened_dim = torch.prod(torch.tensor(lib_poly.shape[1:])).item()

    lib_poly = lib_poly.view(lib_poly.shape[0], flattened_dim).transpose(0, 1)
    U_t = U_t.reshape(-1, 1)

    w_best = STRidge(X0 = lib_poly,
                     y = U_t,
                     lam = lam,
                     maxit = maxit,
                     tol = d_tol,
                     normalize = normalize,
                     device = device,
                     lambda_w = lambda_w
                     )
    # myloss = SimpleLpLoss(size_average=False)
    err_f = torch.mean((U_t - lib_poly @ w_best.to(device)) ** 2)
    err_best = err_f.item() * torch.count_nonzero(w_best) + err_f
    return w_best, err_best

def U_all_compute(yy = torch.randn([20, 128, 128, 10, 3]),
                  device = "cuda:0",
                  upper_bound_x = 2.5,
                  upper_bound_y = 2.5,
                  upper_bound_t = 1,
                  lower_bound_t = 0,
                  lower_bound_x = -2.5,
                  lower_bound_y = -2.5
                  ):
    _, size_x, size_y, size_t = yy.shape[0], yy.shape[1], yy.shape[2], yy.shape[3]
    gridx = torch.tensor(np.linspace(lower_bound_x, upper_bound_x, size_x), dtype=torch.float, device=device)
    gridy = torch.tensor(np.linspace(lower_bound_y, upper_bound_y, size_y), dtype=torch.float, device=device)
    gridt = torch.tensor(np.linspace(lower_bound_t, upper_bound_t, size_t), dtype=torch.float, device=device)

    dx = gridx[1] - gridx[0]
    dy = gridy[1] - gridy[0]
    dt = gridt[1] - gridt[0]
    # U = yy.unsqueeze(0)
    U_t = []  # torch.zeros_like(yy)
    for i in range(yy.shape[-1]):
        yy_channel = yy[..., i:i+1]
        du_dx = torch.gradient(yy_channel, spacing=dx, dim=1)
        du_dy = torch.gradient(yy_channel, spacing=dy, dim=2)
        du_dt = torch.gradient(yy_channel, spacing=dt, dim=3)
        du_dx_dx = torch.gradient(du_dx[0], spacing=dx, dim=1)
        du_dy_dx = torch.gradient(du_dy[0], spacing=dx, dim=1)
        du_dy_dy = torch.gradient(du_dy[0], spacing=dy, dim=2)
        du_dx_dy = torch.gradient(du_dx[0], spacing=dy, dim=2)

        # print("du_dx .shape:", du_dx[0].shape)  # [20, 128, 128, 10, 1]
        # print("du_dy .shape:", du_dy[0].shape)  # [20, 128, 128, 10, 1]
        # print("du_dt .shape:", du_dt[0].shape)
        if i == 0:
            U_t, U_x, U_y, U_xy, U_xx, U_yy, U_yx = du_dt[0], du_dx[0], du_dy[0], du_dx_dy[0], du_dx_dx[0], du_dy_dy[0], du_dy_dx[0]
        else:
            U_t = torch.cat((U_t, du_dt[0]), dim=-1)
            U_x = torch.cat((U_x, du_dx[0]), dim=-1)
            U_y = torch.cat((U_y, du_dy[0]), dim=-1)
            U_xy = torch.cat((U_xy, du_dx_dy[0]), dim=-1)
            U_xx = torch.cat((U_xx, du_dx_dx[0]), dim=-1)
            U_yy = torch.cat((U_yy, du_dy_dy[0]), dim=-1)
            U_yx = torch.cat((U_yx, du_dy_dx[0]), dim=-1)

    U_all = torch.stack([yy, U_x, U_y, U_xy, U_xx, U_yy, U_yx], dim=0)

    return U_all, U_t


def lib_poly_compute(P = 3,
                     U_all = torch.randn([10, 20, 128, 128, 10, 2], device="cuda:0"),
                     U = torch.randn([1, 20, 128, 128, 10, 2], device="cuda:0")
                     ):

    for i in range(U.shape[0]):
        for j in range(1, P+1):
            # print(f"U[i,...]**j.shape = {U[i,...]**j.shape}")
            if i == 0 and j == 1:
                lib_poly = U[i:i+1, ...]**j
            else:
                lib_poly = torch.cat((lib_poly, U[i:i+1, ...]**j), dim=0)
                # print(f"U[i,...]**j.shape = {(U[i,...]**j).shape}")
    lib_poly = torch.cat((lib_poly, torch.ones_like(U[0:1, ...], device=U_all.device)), dim=0)
    U_poly = lib_poly

    for i in range(U_all.shape[0]):
        for j in range(U_poly.shape[0]):
            lib_poly = torch.cat((lib_poly, U_all[i:i+1, ...] * U_poly[j:j+1, ...]), dim=0)

    lib_poly = torch.cat((lib_poly, U_all), dim=0)
    return lib_poly

I suspect the slowdown is related to the complexity of the tensor operations involved in computing these losses, but I’m not sure how to optimize them.
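
One thing I could try is profiling a single backward pass with torch.profiler to see which operators dominate (a minimal sketch, assuming one profiled step is representative; total_loss is the combined loss from the loop above):

from torch.profiler import profile, ProfilerActivity

# profile one backward pass and print the operators sorted by CUDA time
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    total_loss.backward()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))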

Has anyone faced similar issues with slow backpropagation due to complex tensor operations? Any help or suggestions would be greatly appreciated!

Thank you!
#pytorch

Double post from here.