Accumulating gradients of a nonlinear function of summed losses with a small memory footprint

Hello,

I need to perform a global optimization over a large subset of my dataset, so the loss cannot easily be written as a sum of per-sample losses.
More precisely, I wrote an example inspired by this post:

# Reference pattern: build one graph over global_size samples, then
# differentiate func_1(sum of loss_1) * func_2(sum of loss_2) in one pass.
# Memory grows with global_size because every sample's graph is kept alive
# until backward().
# NOTE: samples after the last full window of global_size are never used.
loss_1 = 0
loss_2 = 0
for i, (inp, target) in enumerate(dataset):
    pred = net(inp)
    current_loss_1 = crit_1(pred, target)
    current_loss_2 = crit_2(pred, target)

    # current graph is appended to existing graph
    loss_1 = loss_1 + current_loss_1
    loss_2 = loss_2 + current_loss_2
    if (i + 1) % global_size == 0:
        # every global_size iterations: backward through the large graph
        opt.zero_grad()
        loss = func_1(loss_1) * func_2(loss_2)
        loss.backward()
        # the huge graph is freed by backward() here
        opt.step()
        # Start a fresh window. Without this reset the next window would keep
        # adding to the old (already freed) graph, mixing windows and making
        # the next backward() fail.
        loss_1 = 0
        loss_2 = 0

func_1 and func_2 are two nonlinear functions.
I suppose this example would work, but it will not fit into memory, since global_size might be large.
Mathematically, not all is lost, because I only need gradients of sums. Write $f_1$ for func_1, $f_2$ for func_2, $\ell^{(k)}_i$ for the $k$-th loss on sample $i$, and $L_k = \sum_i \ell^{(k)}_i$. Then

$$
\nabla\!\left[f_1(L_1)\, f_2(L_2)\right]
= f_1'(L_1)\, f_2(L_2)\, \nabla L_1 + f_1(L_1)\, f_2'(L_2)\, \nabla L_2
= f_1'(L_1)\, f_2(L_2) \sum_i \nabla \ell^{(1)}_i + f_1(L_1)\, f_2'(L_2) \sum_i \nabla \ell^{(2)}_i .
$$
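So, in principle, I could compute the gradients of each small step (one sample or a mini-batch), accumulate them together with the loss sums $L_1$ and $L_2$, and free the graph after each step, which keeps memory consumption bounded. The scalar factors $f_1'(L_1) f_2(L_2)$ and $f_1(L_1) f_2'(L_2)$ can only be applied at the end, once the sums are known.
The question is: how do I compute the final gradient and put it in the right place (i.e. in each parameter's .grad)?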
Thanks a lot

Here is a “solution” (I am not sure it is completely general), but it is very ugly.
Maybe someone can improve on it…
First, the reference version, which accumulates the graphs and frees them only at the end (it does not fit in memory in my case):

import torch
def f1(x):
    """Outer function applied to the first summed loss: f1(x) = x squared."""
    return pow(x, 2)
def df1(x):
    """Derivative of f1(x) = x**2, i.e. 2*x."""
    return x * 2
def f2(x):
    """Outer function applied to the second summed loss: f2(x) = x cubed."""
    return pow(x, 3)
def df2(x):
    """Derivative of f2(x) = x**3, i.e. 3*x**2."""
    return 3 * pow(x, 2)
# Baseline: build a single graph over all samples, then differentiate
# f1(sum_i loss1_i) * f2(sum_i loss2_i) with one backward pass.
# Every per-sample graph stays alive until backward().
data = torch.linspace(1, 10, 10)[:, None]
model = torch.nn.Linear(1, 1)
model.zero_grad()
loss = torch.tensor([0.])
loss1, loss2 = torch.zeros_like(loss), torch.zeros_like(loss)
for dat in data:
    pred = model(dat)
    # In-place accumulation attaches this sample's graph to loss1/loss2.
    loss1 += pred
    loss2 += 2. * pred
loss = f1(loss1) * f2(loss2)
loss.backward()
for name, param in model.named_parameters():
    print(f"{name} {param.grad.item()}")
weight 3198.84375
bias 581.60791015625

In this version I compute the gradients and free the graph at each step (the only version that would fit into memory in a real case):

# Memory-friendly variant. For each sample, compute the gradient of each
# partial loss separately, let that sample's graph be freed, and keep only
# running sums:
#   loss1_acc = sum_i loss1_i,   dloss1_acc = sum_i grad(loss1_i)
#   loss2_acc = sum_i loss2_i,   dloss2_acc = sum_i grad(loss2_i)
# By the product and chain rules,
#   grad(f1(L1) * f2(L2)) = f1'(L1) f2(L2) grad(L1) + f1(L1) f2'(L2) grad(L2),
# so the final gradient can be assembled from these sums at the end.
params = dict(model.named_parameters())
param_list = list(params.values())
loss1_acc = torch.tensor([0.])
loss2_acc = torch.tensor([0.])
# Running sums of per-sample gradients, one tensor per parameter.
dloss1_acc = {name: torch.zeros_like(p) for name, p in params.items()}
dloss2_acc = {name: torch.zeros_like(p) for name, p in params.items()}
model.zero_grad()

for dat in data:
    pred = model(dat)
    loss1 = pred
    loss2 = 2. * pred
    # autograd.grad returns the gradients directly instead of accumulating
    # them into .grad, so no snapshot/difference bookkeeping is needed.
    # retain_graph=True keeps pred's graph for the second call, which then
    # frees it.
    grads1 = torch.autograd.grad(loss1, param_list, retain_graph=True)
    grads2 = torch.autograd.grad(loss2, param_list)
    with torch.no_grad():
        # Accumulate values only. Adding tensors that require grad outside
        # no_grad would give the accumulators autograd history and chain every
        # sample's graph together, defeating the purpose.
        loss1_acc += loss1
        loss2_acc += loss2
        for name, g1, g2 in zip(params, grads1, grads2):
            dloss1_acc[name] += g1
            dloss2_acc[name] += g2

# Assemble grad(f1(L1) * f2(L2)) from the accumulated sums. None of these
# tensors carry autograd history, so the stored .grad is a plain tensor.
scale1 = df1(loss1_acc) * f2(loss2_acc)
scale2 = f1(loss1_acc) * df2(loss2_acc)
for name, param in params.items():
    param.grad = scale1 * dloss1_acc[name] + scale2 * dloss2_acc[name]


for name, param in model.named_parameters():
    print(name, param.grad.item())
weight 3198.84375
bias 581.60791015625