I need to perform a global optimization over a large subset of my dataset, so the loss does not decompose easily into a sum of per-sample losses.
More precisely, here is an example inspired by this post:
```python
loss_1 = 0
loss_2 = 0
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    current_loss_1 = crit_1(pred, target)
    current_loss_2 = crit_2(pred, target)
    # current graph is appended to the existing graph
    loss_1 = loss_1 + current_loss_1
    loss_2 = loss_2 + current_loss_2
    if (i + 1) % global_size == 0:
        # every global_size iterations, backprop through one large graph
        opt.zero_grad()
        loss = func_1(loss_1) * func_2(loss_2)
        loss.backward()  # huge graph is freed here
        opt.step()
        # reset the accumulators for the next global step
        loss_1 = 0
        loss_2 = 0
```
func_1 and func_2 are two nonlinear functions.
I suppose this example would work, but it will not fit into memory, since global_size might be large.
Mathematically, not all is lost, since I only need gradients of sums (func_1 is denoted f_1, and similarly for the losses):
grad(f_1(\sum_i loss1_i) f_2(\sum_i loss2_i))
= f'_1(\sum_i loss1_i) f_2(\sum_i loss2_i) grad(\sum_i loss1_i) + f_1(\sum_i loss1_i) f'_2(\sum_i loss2_i) grad(\sum_i loss2_i)
= f'_1(\sum_i loss1_i) f_2(\sum_i loss2_i) \sum_i grad(loss1_i) + f_1(\sum_i loss1_i) f'_2(\sum_i loss2_i) \sum_i grad(loss2_i)
So, in principle, I could compute the gradient at each small step over the dataset, clear the graph each time, and thereby keep memory consumption under control.
The question is: how do I compute the final gradient and accumulate it in the correct place?
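To make the idea concrete, here is a minimal sketch of the two-pass scheme I have in mind, reusing the names net, crit_1, crit_2, func_1, func_2 and opt from the example above; the helper global_step is hypothetical. Pass 1 computes the two loss sums under no_grad, the scalar coefficients f'_1·f_2 and f_1·f'_2 are obtained by differentiating the product of the two sums, and pass 2 backpropagates one small graph per sample so the gradients accumulate in the parameters' .grad fields:

```python
import torch

# Hedged sketch: global_step is a hypothetical helper; net, crit_1, crit_2,
# func_1, func_2 and opt are assumed to be defined as in the question.
def global_step(chunk, net, crit_1, crit_2, func_1, func_2, opt):
    # Pass 1: accumulate the two loss sums without building any graph.
    with torch.no_grad():
        sum_1 = sum(crit_1(net(x), t) for x, t in chunk)
        sum_2 = sum(crit_2(net(x), t) for x, t in chunk)

    # Differentiate func_1(S1) * func_2(S2) w.r.t. the two scalar sums to get
    # the coefficients a = f'_1(S1) f_2(S2) and b = f_1(S1) f'_2(S2).
    s1 = sum_1.detach().requires_grad_(True)
    s2 = sum_2.detach().requires_grad_(True)
    (func_1(s1) * func_2(s2)).backward()
    a, b = s1.grad, s2.grad

    # Pass 2: one small graph per sample; backward() frees it immediately,
    # and the per-sample gradients accumulate in the parameters' .grad fields.
    opt.zero_grad()
    for x, t in chunk:
        pred = net(x)
        (a * crit_1(pred, t) + b * crit_2(pred, t)).backward()
    opt.step()
```

By linearity, the accumulated .grad is exactly a·Σ grad(loss1_i) + b·Σ grad(loss2_i), i.e. the gradient of the full product, but peak memory is that of a single sample's graph. The cost is a second forward pass over the chunk.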
Thanks a lot