Howdy!

I’m wondering whether there is a more memory-efficient way to do the following. The ultimate application is a physics-informed neural network, but I’ve coded this toy example to make it clearer what I want:

```
import torch


def F(x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    return 4 * p * x**2


def f(x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    '''
    dF/dx, which we will integrate with respect to x.
    '''
    return 8 * p * x


def batched_integral_f(p: torch.Tensor, n_batches: int, n_per_batch: int) -> torch.Tensor:
    '''
    Integrates f(x,p) from 0 to 1 using a left Riemann sum in batches of size n_per_batch.
    Done indirectly by taking the partial derivative of the F function wrt x.
    (bounds of integration are fixed in this example so that the integral is the same as the mean value of collected elements)
    '''
    num_points = n_batches * n_per_batch
    sum_val = torch.tensor(0, dtype=torch.float64)
    for i in range(n_batches):
        x: torch.Tensor = (torch.arange(n_per_batch, dtype=torch.float64) + i * n_per_batch) / num_points
        x.requires_grad_()
        x.retain_grad()  # no-op for a leaf tensor such as x
        F_x = F(x, p)
        # sum() called here because each value of F_x depends on exactly ONE of x, so no cross gradients exist.
        f_by_grad = torch.autograd.grad(F_x.sum(), x, create_graph=True)[0]
        x.detach_()  # is likely a noop but I don't need to compute further derivatives wrt x, only to p
        sum_val += torch.sum(f_by_grad)
    return sum_val / num_points


def main():
    p = torch.tensor(4, dtype=torch.float64, requires_grad=True)
    int_f = batched_integral_f(p, n_batches=1000, n_per_batch=10000)
    F_ref = F(1, p)
    print('batched_integral_f: {} F(1,p): {}'.format(int_f.item(), F_ref.item()))
    dFdp = torch.autograd.grad(int_f, p)[0]
    dF_refdp = torch.autograd.grad(F_ref, p)[0]
    print('batched_dFdp: {} dF(1,p)dp: {}'.format(dFdp.item(), dF_refdp.item()))


main()
```

As the example shows, I would like to compute the derivative of the integral function `batched_integral_f` with respect to the function parameter `p`. The issue is that every value of `x` I compute stays inside the computational graph, since this function computes partial derivatives of `F` with respect to `x`. When I sum the values this way, the memory on the PyTorch device quickly fills up, even though the calculations are done in batches.

Is there a way around this, either through a selective pruning of the autograd graph, or some clever accumulation method that is compatible with an optimizer, which would allow such a batched computation of my loss function?
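
For concreteness, here is a minimal sketch of the kind of accumulation I have in mind. It assumes it is acceptable to differentiate each batch's contribution with respect to `p` immediately, so that the batch's graph can be freed before the next one is built. The name `accumulate_integral_and_grad` is hypothetical, and I'm not sure this is the recommended approach or that it carries over cleanly to a full network.

```python
import torch


def F(x, p):
    return 4 * p * x**2


def accumulate_integral_and_grad(p, n_batches, n_per_batch):
    """Hypothetical helper: returns the left Riemann sum of dF/dx on [0, 1]
    as a float, plus its derivative wrt p, accumulated batch by batch."""
    num_points = n_batches * n_per_batch
    total = 0.0
    grad_p = torch.zeros_like(p)  # plain accumulator, not tracked by autograd
    for i in range(n_batches):
        x = (torch.arange(n_per_batch, dtype=torch.float64) + i * n_per_batch) / num_points
        x.requires_grad_()
        # create_graph=True keeps dF/dx differentiable wrt p.
        f_by_grad = torch.autograd.grad(F(x, p).sum(), x, create_graph=True)[0]
        batch_term = f_by_grad.sum() / num_points
        # Differentiate this batch's contribution right away; no reference to
        # its graph is kept once the loop moves on to the next batch.
        grad_p += torch.autograd.grad(batch_term, p)[0]
        total += batch_term.item()
    return total, grad_p


p = torch.tensor(4, dtype=torch.float64, requires_grad=True)
total, grad_p = accumulate_integral_and_grad(p, n_batches=10, n_per_batch=100)
print(total, grad_p.item())
```

With an optimizer, I assume the accumulated value could be assigned to `p.grad` before calling `optimizer.step()`. Alternatively, calling `batch_term.backward()` in each iteration should accumulate into `p.grad` directly.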

(Footnote for added context, skip if it’s clear what I’ve asked: to see why this is necessary, suppose that the integral function in my example is my network loss. It depends on my network inputs and the internal parameters of the network, and contains terms that include the partial derivative of the network with respect to its inputs at certain points. To minimize this loss function, my optimizer needs the derivative of the loss with respect to the network parameters, and I’d like to compute this loss over an arbitrary number of network input samples.)

Thank you in advance.