Get gradients of parameters w.r.t. a loss term in PyTorch

In my PyTorch training I use a composite loss function defined as

$\mathcal{L}_{total}(\theta) = \mathcal{L}_{int}(\theta) + \alpha\,\mathcal{L}_{D}(\theta) + \beta\,\mathcal{L}_{N}(\theta)$

(interior, Dirichlet and Neumann terms). In order to update the weights $\alpha$ and $\beta$, I need to compute three values

$\overline{\nabla_\theta \mathcal{L}_{int}},\;\overline{\nabla_\theta \mathcal{L}_{D}},\;\overline{\nabla_\theta \mathcal{L}_{N}}$, with $\overline{\nabla_\theta \mathcal{L}_i} = \frac{1}{|\theta|}\sum_{j}\frac{\partial \mathcal{L}_i}{\partial \theta_j}$,

which are the means of the gradients of each loss term w.r.t. all the weights $\theta$ of the network.

Is there an efficient way to write this in PyTorch?

My training code looks like:

for epoch in range(1, nb_epochs+1):
  #init
  optimizer_fo.zero_grad()
  #get the current loss
  loss_total = mynet_fo.loss(tensor_xy_dirichlet,g_boundaries_d,tensor_xy_inside,tensor_f_inter,tensor_xy_neuman,g_boundaries_n)
  #compute gradients
  loss_total.backward(retain_graph=True)
  #optimize
  optimizer_fo.step()

where my .loss() function directly returns the sum of the terms. I've thought of making a second forward pass and calling backward() on each loss term independently, but that would be very expensive.
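
To make the idea concrete, here is a rough sketch of what I have in mind. It assumes a hypothetical loss_terms() variant of .loss() that returns the three terms separately from a single forward pass, and that alpha and beta are plain scalars, so this is just an illustration, not working code from my project:

import torch

params = [p for p in mynet_fo.parameters() if p.requires_grad]

# one forward pass; loss_terms() is a hypothetical variant of .loss()
# that returns the three terms instead of their sum
loss_d, loss_i, loss_n = mynet_fo.loss_terms(tensor_xy_dirichlet, g_boundaries_d,
                                             tensor_xy_inside, tensor_f_inter,
                                             tensor_xy_neuman, g_boundaries_n)

def grad_mean(term):
    # gradients of one loss term w.r.t. all network weights, without
    # writing into .grad; retain_graph=True so the same graph can be reused
    grads = torch.autograd.grad(term, params, retain_graph=True)
    return torch.cat([g.reshape(-1) for g in grads]).mean()

mean_d = grad_mean(loss_d)
mean_i = grad_mean(loss_i)
mean_n = grad_mean(loss_n)

# ... update alpha and beta from the three means here ...

# usual optimization step on the weighted sum
loss_total = loss_i + alpha * loss_d + beta * loss_n
optimizer_fo.zero_grad()
loss_total.backward()
optimizer_fo.step()

Even with the graph shared through retain_graph=True, this still runs one backward pass per term on top of the final one, which is the cost I would like to avoid or reduce if possible.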