In my PyTorch training I use a composite loss function defined as:
To update the weights alpha and beta, I need to compute three values: the means of the gradients of each loss term with respect to all the weights in the network.
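Written out explicitly, the three values would be as follows. The notation is introduced here for illustration and is not from the original: $\mathcal{L}_k$ is the $k$-th loss term and $\theta_1, \dots, \theta_P$ are all $P$ scalar weights of the network.

$$\bar{g}_k = \frac{1}{P} \sum_{j=1}^{P} \frac{\partial \mathcal{L}_k}{\partial \theta_j}, \qquad k = 1, 2, 3.$$

Each $\bar{g}_k$ needs the gradient of one term on its own, whereas a single `backward()` on the summed loss only yields the gradient of the sum, from which the individual terms cannot be separated.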
Is there an efficient way to do this in PyTorch?
My training code looks like this:
```python
for epoc in range(1, nb_epochs + 1):
    # init
    optimizer_fo.zero_grad()
    # get the current loss
    loss_total = mynet_fo.loss(tensor_xy_dirichlet, g_boundaries_d, tensor_xy_inside,
                               tensor_f_inter, tensor_xy_neuman, g_boundaries_n)
    # compute gradients
    loss_total.backward(retain_graph=True)
    # optimize
    optimizer_fo.step()
```
Here my `.loss()` function directly returns the sum of the terms. I've thought of making a second forward pass and calling `backward()` on each loss term independently, but that would be very expensive.
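One way to avoid a second forward pass might be to keep the graph from a single forward pass and differentiate each term separately with `torch.autograd.grad(..., retain_graph=True)`. The sketch below assumes `.loss()` is changed to return the three terms instead of their sum, and that the composite loss is `L_inside + alpha * L_dirichlet + beta * L_neumann` (an assumption, since the exact formula is not shown). The toy network, random data, `loss_terms()` helper and the fixed `alpha`/`beta` values are hypothetical stand-ins for `mynet_fo` and the real boundary data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for mynet_fo: a small MLP
net = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 1))
params = [p for p in net.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=1e-3)

# Hypothetical collocation points and boundary targets
x_inside, x_dirichlet, x_neumann = torch.randn(64, 2), torch.randn(32, 2), torch.randn(32, 2)
g_d, g_n = torch.zeros(32, 1), torch.zeros(32, 1)

def loss_terms():
    """Hypothetical replacement for .loss(): return the terms, not their sum.
    These are toy terms; a real Neumann term would involve output derivatives."""
    l_inside = net(x_inside).pow(2).mean()
    l_dirichlet = (net(x_dirichlet) - g_d).pow(2).mean()
    l_neumann = (net(x_neumann) - g_n).pow(2).mean()
    return [l_inside, l_dirichlet, l_neumann]

alpha, beta = 1.0, 1.0  # assumed initial weights
history = []            # mean gradient of each term, one entry per epoch

for epoch in range(1, 6):
    terms = loss_terms()  # one forward pass builds one shared graph
    per_term_grads = []
    for i, term in enumerate(terms):
        # Keep the graph alive until the last term has been differentiated
        retain = i < len(terms) - 1
        per_term_grads.append(torch.autograd.grad(term, params, retain_graph=retain))

    # Mean of d(term)/d(theta) over every scalar weight in the network
    means = [torch.cat([g.reshape(-1) for g in grads]).mean().item()
             for grads in per_term_grads]
    history.append(means)
    # ... update alpha and beta from `means` here (rule not given in the question) ...

    # Gradient of the assumed total L_inside + alpha*L_dirichlet + beta*L_neumann,
    # assembled from the per-term gradients and written directly into p.grad
    for j, p in enumerate(params):
        p.grad = per_term_grads[0][j] + alpha * per_term_grads[1][j] + beta * per_term_grads[2][j]
    optimizer.step()
```

This costs one backward pass per term but no extra forward pass. Because the total is assumed to be a linear combination with `alpha` and `beta` fixed within a step, its gradient can be assembled from the per-term gradients, so no further backward pass on the summed loss is needed. Since `p.grad` is overwritten rather than accumulated, `optimizer.zero_grad()` is not required in this loop.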