Hello everyone,
I'm working on a deep learning model trained with three loss terms. Backpropagation through two of them is very slow, and it dominates the training time.
I suspect the problem lies in how these losses are computed. Here is the training loop:
# Training loop (excerpt). Indentation restored from the post; nesting is
# epochs > batches > auto-regressive time bundles.
#
# Anomaly detection records a traceback for every forward op and checks every
# backward output; it is a debugging aid that makes backward much slower. It
# used to be enabled unconditionally every epoch. Opt in with
# args.detect_anomaly only when hunting NaN/inf sources.
if getattr(args, "detect_anomaly", False):
    torch.autograd.set_detect_anomaly(True)
# Only needs to happen once; it was previously repeated for every batch.
model.to(device)
for ep in range(args.epochs):
    model.train()
    t1 = t_1 = default_timer()
    t_load, t_train = 0., 0.
    train_l2_step = 0
    train_l2_full = 0
    loss_previous = np.inf
    for xx, yy, msk, cls in train_loader:
        t_load += default_timer() - t_1
        t_1 = default_timer()
        loss, physics_loss, coefficient_loss = 0., 0., 0.
        # The former torch.cuda.empty_cache() calls (per batch and per bundle)
        # were removed: they only return cached blocks to the driver, so the
        # next allocation pays for cudaMalloc again. They never free tensors
        # that are still alive in the autograd graph.
        for t in range(0, yy.shape[-2], args.T_bundle):
            y = yy[..., t:t + args.T_bundle, :]
            # Auto-regressive training: perturb the input window with Gaussian
            # noise scaled by its norm over dims 1-3.
            xx = xx + args.noise_scale * torch.sum(
                xx ** 2, dim=(1, 2, 3), keepdim=True
            ) ** 0.5 * torch.randn_like(xx)
            im, _ = model(xx)
            loss += myloss(im, y, mask=msk)
            if t == 0:
                pred = im
            else:
                pred = torch.cat((pred, im), dim=-2)
            batch_size, x_size, y_size, t_size, channels = xx.shape
            # NOTE(review): physics_loss and coefficient_loss backpropagate
            # through U_all_compute, lib_poly_compute and STRidge for every
            # bundle, over every grid point. If that graph is not needed, it
            # is the likely bulk of backward time.
            U_all, U_t = U_all_compute(xx,
                                       device=device,
                                       upper_bound_x=upper_bound_x,
                                       upper_bound_y=upper_bound_y,
                                       upper_bound_t=upper_bound_t,
                                       lower_bound_t=lower_bound_t,
                                       lower_bound_x=lower_bound_x,
                                       lower_bound_y=lower_bound_y)
            U = xx.unsqueeze(0)
            lib_poly = lib_poly_compute(P=3,
                                        U_all=U_all,
                                        U=U)
            lam, d_tol, maxit, STR_iters, normalize, l0_penalty, print_best_tol = 1e-5, 1, 50, 10, 2, 1e-4, False
            # One weight per library term (e.g. 16 x 1).
            lambda_w = torch.randn([lib_poly.shape[0], 1]).to(device)
            start_time = time.time()
            w_best, phi_loss = Train_STRidge(lib_poly=lib_poly,
                                             U_t=U_t,
                                             device=device,
                                             lam=lam,
                                             maxit=maxit,
                                             normalize=normalize,
                                             lambda_w=lambda_w,
                                             l0_penalty=l0_penalty,
                                             print_best_tol=False,
                                             d_tol=d_tol)
            phi_loss_copy = phi_loss.clone()  # .detach()
            physics_loss = physics_loss + phi_loss_copy
            if t == 0:
                # Coefficients from the first bundle are the reference that
                # later bundles are pulled towards.
                A_real = w_best
                pred_w = A_real
            else:
                A_real_copy = A_real.clone()  # .detach()
                w_best_copy = w_best.clone()  # .detach()
                coefficient_loss = coefficient_loss + torch.mean((A_real_copy - w_best_copy) ** 2)
                pred_w = torch.cat((pred_w, w_best), dim=-1)
            # Slide the input window: drop the oldest T_bundle steps and append
            # the model's prediction.
            xx = torch.cat((xx[..., args.T_bundle:, :], im), dim=-2)
        optimizer.zero_grad()
        total_loss = loss + coefficient_loss + physics_loss
        total_loss.backward()
        # NOTE(review): the excerpt ends here; no optimizer.step() is visible.
        # Confirm it follows.
And the functions used are defined here:
def Train_STRidge(lib_poly=None,
                  U_t=None,
                  device="cuda:0",
                  lam=1e-5,
                  maxit=100,
                  normalize=2,
                  lambda_w=None,
                  l0_penalty=1e-4,
                  print_best_tol=False,
                  d_tol=1
                  ):
    """Run STRidge on a candidate-term library and score the fitted coefficients.

    Args:
        lib_poly: Library tensor whose first axis indexes candidate terms; all
            remaining axes are flattened so each sample point becomes one row.
            If omitted, a random ``[16, 20, 128, 128, 10, 2]`` tensor on
            ``cuda:0`` is created at call time.
        U_t: Regression target, flattened into a single column. If omitted, a
            random ``[1, 20, 128, 128, 10, 2]`` tensor on ``cuda:0`` is created.
        device: Device ``w_best`` is moved to before computing the residual;
            also forwarded to ``STRidge``.
        lam, maxit, normalize: Forwarded unchanged to ``STRidge``.
        lambda_w: Forwarded to ``STRidge``. If omitted, a random ``[16, 1]``
            tensor on ``cuda:0`` is created.
        l0_penalty: Accepted but currently unused.
        print_best_tol: Accepted but currently unused.
        d_tol: Forwarded to ``STRidge`` as ``tol``.

    Returns:
        ``(w_best, err_best)``: the coefficients returned by ``STRidge`` and
        ``mse * (1 + count_nonzero(w_best))``, where only the plain ``mse``
        term carries gradient.
    """
    # These defaults used to be evaluated when the module was imported, which
    # allocated ~440 MB of GPU memory for the life of the process and failed
    # outright on machines without CUDA. Build them only when omitted.
    if lib_poly is None:
        lib_poly = torch.randn([16, 20, 128, 128, 10, 2], device="cuda:0")
    if U_t is None:
        U_t = torch.randn([1, 20, 128, 128, 10, 2], device="cuda:0")
    if lambda_w is None:
        lambda_w = torch.randn([16, 1]).to("cuda:0")

    # (n_terms, *grid) -> (n_points, n_terms): one row per grid point,
    # one column per candidate term.
    lib_poly = lib_poly.reshape(lib_poly.shape[0], -1).transpose(0, 1)
    U_t = U_t.reshape(-1, 1)
    w_best = STRidge(
        X0=lib_poly,
        y=U_t,
        lam=lam,
        maxit=maxit,
        tol=d_tol,
        normalize=normalize,
        device=device,
        lambda_w=lambda_w
    )
    err_f = torch.mean((U_t - lib_poly @ w_best.to(device)) ** 2)
    # NOTE(review): l0_penalty is never used; the sparsity term scales the
    # residual by the number of non-zero coefficients instead of adding
    # l0_penalty * nnz. Confirm which is intended.
    # detach() gives the same value and gradient as the former err_f.item()
    # but avoids a host/device synchronisation on every call.
    err_best = err_f.detach() * torch.count_nonzero(w_best) + err_f
    return w_best, err_best
def U_all_compute(yy=None,
                  device="cuda:0",
                  upper_bound_x=2.5,
                  upper_bound_y=2.5,
                  upper_bound_t=1,
                  lower_bound_t=0,
                  lower_bound_x=-2.5,
                  lower_bound_y=-2.5
                  ):
    """Compute finite-difference derivative fields of ``yy`` on a uniform grid.

    Axes 1, 2 and 3 of ``yy`` are treated as x, y and t; the last axis holds
    channels that are differentiated independently. The grid step along each
    axis is that of ``np.linspace(lower, upper, size)`` rounded to float32.

    Args:
        yy: Field tensor, at least 4-D with x, y, t on axes 1-3. If omitted, a
            random ``[20, 128, 128, 10, 3]`` CPU tensor is created at call time.
        device: Kept for backward compatibility. The grid steps are computed
            on the host, so it no longer affects the result.
        upper_bound_x, lower_bound_x: Extent of the x grid.
        upper_bound_y, lower_bound_y: Extent of the y grid.
        upper_bound_t, lower_bound_t: Extent of the t grid.

    Returns:
        ``(U_all, U_t)``:
            ``U_all`` stacks ``[yy, u_x, u_y, u_xy, u_xx, u_yy, u_yx]`` along
            a new axis 0, each with the shape of ``yy``. Here ``u_xy`` is
            d/dy of u_x and ``u_yx`` is d/dx of u_y.
            ``U_t`` is du/dt, with the shape of ``yy``.
    """
    if yy is None:
        # Previously built once at import time (~40M floats held forever).
        yy = torch.randn([20, 128, 128, 10, 3])

    def _grid_step(lower, upper, n):
        # Same value as the former
        # torch.tensor(np.linspace(...), dtype=torch.float)[1] - [0]:
        # round the grid points to float32, then subtract in float32. Only the
        # step is needed, so no full grid tensor is allocated on the device.
        points = np.linspace(lower, upper, n).astype(np.float32)
        return float(points[1] - points[0])

    dx = _grid_step(lower_bound_x, upper_bound_x, yy.shape[1])
    dy = _grid_step(lower_bound_y, upper_bound_y, yy.shape[2])
    dt = _grid_step(lower_bound_t, upper_bound_t, yy.shape[3])

    # torch.gradient works along a single axis and never mixes the channel
    # axis. Differentiating all channels at once therefore gives exactly what
    # the old per-channel slice + torch.cat loop produced, with far fewer
    # autograd nodes for backward to traverse.
    U_x = torch.gradient(yy, spacing=dx, dim=1)[0]
    U_y = torch.gradient(yy, spacing=dy, dim=2)[0]
    U_t = torch.gradient(yy, spacing=dt, dim=3)[0]
    U_xx = torch.gradient(U_x, spacing=dx, dim=1)[0]
    U_xy = torch.gradient(U_x, spacing=dy, dim=2)[0]  # d/dy (du/dx)
    U_yy = torch.gradient(U_y, spacing=dy, dim=2)[0]
    U_yx = torch.gradient(U_y, spacing=dx, dim=1)[0]  # d/dx (du/dy)
    # NOTE(review): U_xy and U_yx apply the same two one-axis difference
    # operators in opposite order, so they should agree up to rounding. That
    # gives the regression library near-duplicate columns. Dropping one would
    # change the library size, so it is left as is here.
    U_all = torch.stack([yy, U_x, U_y, U_xy, U_xx, U_yy, U_yx], dim=0)
    return U_all, U_t
def lib_poly_compute(P=3,
                     U_all=None,
                     U=None
                     ):
    """Build the candidate-term library used by the sparse regression.

    Rows are stacked along axis 0 in this order:

    1. ``U[i] ** j`` for every ``i`` in ``range(U.shape[0])`` and ``j`` in
       ``1..P``, with ``i`` varying slowest.
    2. A constant row of ones shaped like ``U[0:1]``, on ``U_all``'s device.
       Rows 1-2 together form ``U_poly``.
    3. ``U_all[i] * U_poly[j]`` for every ``i`` and ``j``, with ``i`` varying
       slowest. Row offset is ``i * len(U_poly) + j``.
    4. ``U_all`` itself.

    Args:
        P: Highest power of ``U`` to include.
        U_all: Derivative fields stacked along axis 0. If omitted, a random
            ``[10, 20, 128, 128, 10, 2]`` tensor on ``cuda:0`` is created at
            call time.
        U: Base fields stacked along axis 0; trailing shape must match
            ``U_all``'s. If omitted, a random ``[1, 20, 128, 128, 10, 2]``
            tensor on ``cuda:0`` is created.

    Returns:
        Tensor with ``len(U_poly) * (1 + U_all.shape[0]) + U_all.shape[0]``
        rows along axis 0.
    """
    # These defaults used to be evaluated at import time (~290 MB of GPU
    # memory held forever, and an import failure without CUDA).
    if U_all is None:
        U_all = torch.randn([10, 20, 128, 128, 10, 2], device="cuda:0")
    if U is None:
        U = torch.randn([1, 20, 128, 128, 10, 2], device="cuda:0")

    poly_terms = [U[i:i + 1] ** j for i in range(U.shape[0]) for j in range(1, P + 1)]
    poly_terms.append(torch.ones_like(U[0:1], device=U_all.device))
    U_poly = torch.cat(poly_terms, dim=0)

    # One broadcasted multiply: shape (n_deriv, n_poly, ...) flattened so that
    # row i * n_poly + j is U_all[i] * U_poly[j], the same order as before.
    cross_terms = (U_all.unsqueeze(1) * U_poly.unsqueeze(0)).flatten(0, 1)

    # A single concatenation. Growing the library with torch.cat inside the
    # loops copied an ever-larger tensor ~40 times, and backward had to walk
    # the same chain of cats, so both passes were quadratic in the number of
    # terms.
    return torch.cat((U_poly, cross_terms, U_all), dim=0)
I suspect the cost comes from the tensor operations that build these losses (the finite-difference derivatives, the polynomial term library, and the STRidge regression), but I'm not sure how to optimize them.
Has anyone faced similar issues with slow backpropagation due to complex tensor operations? Any help or suggestions would be greatly appreciated!
Thank you!
#pytorch