In my PyTorch training I use a composite loss function made of three terms, weighted by alpha and beta, defined as:
To update the weights alpha and beta, I need to compute three values: the means of the gradients of each loss term with respect to all the weights in the network.
Is there an efficient way to write this in PyTorch?
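For a single term, one possible way to compute this quantity is `torch.autograd.grad`, which returns the gradients without writing into `.grad`. The sketch below is only an illustration under assumptions: `mean_grad` is a hypothetical helper, the network and the two loss terms are dummies, and the `use_abs` switch exists because the exact formula is not given here (gradient-balancing schemes often average absolute values).

```python
import torch
import torch.nn as nn

def mean_grad(loss_term, params, use_abs=True):
    """Hypothetical helper: mean of d(loss_term)/d(theta) over all parameter entries."""
    # retain_graph=True keeps the graph so the other terms can still be differentiated
    grads = torch.autograd.grad(loss_term, params, retain_graph=True, allow_unused=True)
    # Flatten every parameter's gradient into one vector (skip unused parameters)
    flat = torch.cat([g.reshape(-1) for g in grads if g is not None])
    if use_abs:
        flat = flat.abs()
    return flat.mean()

# Tiny demo network and two dummy loss terms sharing one forward pass
torch.manual_seed(0)
net = nn.Linear(3, 1)
x = torch.randn(8, 3)
out = net(x)
loss_a = out.pow(2).mean()
loss_b = out.abs().mean()
params = [p for p in net.parameters() if p.requires_grad]
print(mean_grad(loss_a, params), mean_grad(loss_b, params))
```

Each call costs one backward pass through the shared graph, but no additional forward pass.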
My training code looks like this:
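```python
for epoc in range(1, nb_epochs + 1):
    # reset the gradients accumulated during the previous iteration
    optimizer_fo.zero_grad()
    # forward pass: .loss() returns the sum of the loss terms
    loss_total = mynet_fo.loss(tensor_xy_dirichlet, g_boundaries_d,
                               tensor_xy_inside, tensor_f_inter,
                               tensor_xy_neuman, g_boundaries_n)
    # backward pass; retain_graph=True is only needed if backward() is called again on this graph
    loss_total.backward(retain_graph=True)
    # update the network weights
    optimizer_fo.step()
```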
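One possible restructuring of this loop keeps a single forward pass, computes the three mean gradients from the same graph, updates alpha and beta, and only then backpropagates the weighted total. This is a sketch under loudly stated assumptions: `loss_terms` is a hypothetical replacement for `.loss()` that returns the terms separately, the network and the three terms are dummies, the weighting `loss_r + alpha * loss_d + beta * loss_n` is assumed, and the moving-average ratio update is a placeholder for whatever rule the actual formula prescribes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for mynet_fo; the inputs and loss terms below are dummies
net = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 1))
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
params = [p for p in net.parameters() if p.requires_grad]

x_inside = torch.randn(64, 2)
x_dirichlet = torch.randn(16, 2)
x_neumann = torch.randn(16, 2)

def loss_terms():
    """Hypothetical replacement for .loss(): returns the three terms separately."""
    loss_r = net(x_inside).pow(2).mean()               # interior term (dummy)
    loss_d = (net(x_dirichlet) - 1.0).pow(2).mean()    # Dirichlet term (dummy)
    loss_n = net(x_neumann).abs().mean()               # Neumann term (dummy)
    return loss_r, loss_d, loss_n

def mean_abs_grad(term):
    # One backward pass per term; the graph is kept for the next term
    grads = torch.autograd.grad(term, params, retain_graph=True, allow_unused=True)
    return torch.cat([g.reshape(-1) for g in grads if g is not None]).abs().mean()

alpha, beta, momentum = 1.0, 1.0, 0.9
for epoch in range(1, 101):
    optimizer.zero_grad()
    loss_r, loss_d, loss_n = loss_terms()   # single forward pass
    g_r, g_d, g_n = (mean_abs_grad(t) for t in (loss_r, loss_d, loss_n))
    # Assumed update rule: ratio of mean gradients, smoothed by a moving average
    alpha = momentum * alpha + (1 - momentum) * (g_r / (g_d + 1e-8)).item()
    beta = momentum * beta + (1 - momentum) * (g_r / (g_n + 1e-8)).item()
    loss_total = loss_r + alpha * loss_d + beta * loss_n
    loss_total.backward()   # last use of this graph, so no retain_graph needed
    optimizer.step()
```

The `.item()` calls turn alpha and beta into plain floats, so they act as constants and no gradient flows through the weights themselves.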
Here, my `.loss()` function directly returns the sum of the terms. I've thought of doing a second forward pass and calling backward on each loss term independently, but that would be very expensive.
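A second forward pass is not strictly necessary: all terms can be differentiated on the same graph, and because the gradient is linear, the weighted sum of the per-term gradients equals the gradient of the weighted loss. The per-term gradients can therefore be reused to fill `.grad` directly instead of calling `backward()` on the total. In the sketch below the network, the three terms and the weight values are illustrative assumptions, not the original model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(4, 1)
params = list(net.parameters())
x = torch.randn(10, 4)

out = net(x)                                   # single forward pass
terms = [out.pow(2).mean(), (out - 1).abs().mean(), out.sum()]
weights = [1.0, 0.5, 2.0]                      # e.g. 1, alpha, beta (illustrative values)

# One backward per term, all on the same retained graph
per_term = [torch.autograd.grad(t, params, retain_graph=True) for t in terms]

# Linearity: sum_k w_k * grad(L_k) == grad(sum_k w_k * L_k),
# so the total gradient is assembled without an extra backward pass
for i, p in enumerate(params):
    p.grad = sum(w * g[i] for w, g in zip(weights, per_term))

# optimizer.step() could now be called as usual
```

With this approach the cost per iteration is one forward pass plus one backward pass per term, the same number of backward passes needed just to obtain the per-term gradient statistics.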