Reducing memory demands in batched higher-order derivative computation (application to Physics-Informed Neural Networks)

Howdy!

I’m wondering whether there is a more memory-efficient way to do what I’m trying to do. The ultimate application is a physics-informed neural network, but I’ve coded the following toy example to make clear what I’m after:


import torch

def F(x: torch.Tensor, p: torch.Tensor):
    return 4*p*x**2

def f(x: torch.Tensor, p: torch.Tensor):
    ''' 
    dF/dx, which we will integrate with respect to x.
    '''
    return 8*p*x

def batched_integral_f(p: torch.Tensor, n_batches: int, n_per_batch: int):
    '''
    Integrates f(x,p) from 0 to 1 with a left-endpoint Riemann sum, in batches of size n_per_batch.
    Done indirectly by taking the partial derivative of the F function wrt x.
    (The integration interval has length 1, so the integral is just the mean of the collected f values.)
    '''
    num_points = n_batches*n_per_batch
    sum_val = torch.tensor(0,dtype=torch.float64)
    for i in range(n_batches):
        x:torch.Tensor = (torch.arange(n_per_batch,dtype=torch.float64) + i*n_per_batch)/num_points
        x.requires_grad_()
        x.retain_grad()
        F_x = F(x,p)
        f_by_grad = torch.autograd.grad(F_x.sum(),x,create_graph=True)[0] # sum() called here because each value of F_x depends on exactly ONE of x, so no cross gradients exist.
        x.detach_() # likely a no-op, but I don't need further derivatives wrt x, only wrt p
        sum_val += torch.sum(f_by_grad)
    return sum_val/ num_points

def main():
    p = torch.tensor(4,dtype=torch.float64, requires_grad=True)
    int_f = batched_integral_f(p, n_batches=1000, n_per_batch=10000)
    F_ref = F(1,p)
    print('batched_integral_f: {} F(1,p): {}'.format(int_f.item(),F_ref.item()))

    dFdp = torch.autograd.grad(int_f,p)[0]
    dF_refdp = torch.autograd.grad(F_ref,p)[0]
    print('batched_dFdp: {} dF(1,p)dp: {}'.format(dFdp.item(),dF_refdp.item()))


main()

As the example shows, I would like to compute the derivative of the integral function batched_integral_f with respect to the parameter p. The issue is that every x tensor I create stays in the computational graph, because the function takes partial derivatives of F with respect to x with create_graph=True. When I sum the batch values this way, memory on the PyTorch device fills up quickly, even though the calculation is done in batches.

Is there a way around this, either through selective pruning of the autograd graph, or through some clever accumulation method, compatible with an optimizer, that would allow such a batched computation of my loss function?
(Footnote for added context, skip if it’s clear what I’ve asked: the integral function in my example stands in for my network loss. The loss depends on the network inputs and the network’s internal parameters, and contains terms involving the partial derivative of the network output with respect to its inputs at certain points. To minimize this loss, the optimizer needs its derivative with respect to the network parameters, and I’d like to compute the loss over an arbitrary number of input samples.)
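To make the second option more concrete, here is a rough sketch of the kind of per-batch accumulation I have in mind (reusing F and p from the toy example above, and assuming p.grad starts out cleared): each batch calls backward() immediately, so nothing from that batch has to stay in memory while the gradient with respect to p accumulates in p.grad. I’m not sure whether this is sound for the higher-order terms in my real loss, which is part of what I’m asking.

def batched_integral_f_accum(p: torch.Tensor, n_batches: int, n_per_batch: int):
    '''
    Same quantity as batched_integral_f, but each batch contributes its
    gradient to p.grad right away instead of keeping its graph alive.
    '''
    num_points = n_batches * n_per_batch
    total = 0.0
    for i in range(n_batches):
        x = (torch.arange(n_per_batch, dtype=torch.float64) + i * n_per_batch) / num_points
        x.requires_grad_()
        F_x = F(x, p)
        f_by_grad = torch.autograd.grad(F_x.sum(), x, create_graph=True)[0]
        batch_loss = torch.sum(f_by_grad) / num_points
        batch_loss.backward()       # accumulates d(batch_loss)/dp into p.grad; this batch's graph can then be released
        total += batch_loss.item()  # keep a plain float, not a graph reference
    return total                    # p.grad now holds the derivative of the full sum wrt p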

Thank you in advance.

I did some more digging, and it seemed there was no way to selectively prune the computational graph built around each created x value. For this reason I deemed it necessary to write a custom autograd Function that computes fresh gradients on each pass.

This new class computes both the value and its gradient with respect to p inside forward(), batch by batch, summing the results over all batches. backward() then simply returns the accumulated gradient.

from typing import Any, Tuple

import torch

class BatchedIntF(torch.autograd.Function):

    @staticmethod
    def forward(p, n_batches:int, n_per_batch:int):
        num_points = n_batches * n_per_batch
        sum_val = torch.tensor(0, dtype=torch.float64)
        grad_val = torch.tensor(0,dtype=torch.float64)
        for i in range(n_batches):
            x: torch.Tensor = (torch.arange(n_per_batch, dtype=torch.float64) + i * n_per_batch) / num_points
            x.requires_grad_()
            # detach p so each batch's graph stays local to this forward() and
            # never links back to the caller's p
            p = p.detach().requires_grad_()
            with torch.enable_grad():
                F_x = F(x, p)
                f_by_grad = torch.autograd.grad(F_x.sum(), x, create_graph=True)[0]
                sum_term = torch.sum(f_by_grad)
                # accumulate d(sum_term)/dp now, while this batch's graph still exists
                grad_val += torch.autograd.grad(sum_term, p, create_graph=False)[0]
            sum_val += sum_term
        return sum_val / num_points, (grad_val/num_points).detach()

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        batched_int, summed_grad = output
        # mark the gradient output as non-differentiable here (this belongs in
        # setup_context, not in backward())
        ctx.mark_non_differentiable(summed_grad)
        ctx.save_for_backward(batched_int, summed_grad)

    @staticmethod
    def backward(ctx: Any, grad_output: Any, grad_grad_val: Any) -> Any:
        _, grad_val = ctx.saved_tensors
        return grad_val * grad_output, None, None  # gradient is None with respect to the non-tensor inputs

def batched_integral_f_nn_func(p,n_batches:int,n_per_batch:int):
    res,_ = BatchedIntF.apply(p,n_batches,n_per_batch)
    return res

def F(x: torch.Tensor, p: torch.Tensor):
    return 4*p*x**2



def main():
    p = torch.tensor(4,dtype=torch.float64, requires_grad=True)
    F_ref = F(1,p)
    dF_refdp = torch.autograd.grad(F_ref, p)[0]


    p = torch.tensor(4, dtype=torch.float64, requires_grad=True)
    batched_int_Function_compute = batched_integral_f_nn_func(p,n_batches = 10000,n_per_batch = 10000)
    batched_int_Function_compute.backward()
    dfdp_nn_compute = p.grad

    print('F(1,p): {}, Computed F(1,p): {}'.format(F_ref.item(), batched_int_Function_compute.item()))
    print('batched_dFdp_NN: {}, dF(1,p)dp: {}'.format(dfdp_nn_compute.item(), dF_refdp.item()))


main()
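For completeness, a minimal sketch of how this can be hooked into an optimizer step, treating p as the only trainable parameter (the SGD optimizer and learning rate below are placeholders, just to show the mechanics):

def train_step(p, optimizer, n_batches: int = 100, n_per_batch: int = 1000):
    optimizer.zero_grad()
    loss = batched_integral_f_nn_func(p, n_batches, n_per_batch)
    loss.backward()   # BatchedIntF.backward() hands back the gradient accumulated in forward()
    optimizer.step()
    return loss.item()

p = torch.tensor(4, dtype=torch.float64, requires_grad=True)
optimizer = torch.optim.SGD([p], lr=1e-2)
print(train_step(p, optimizer))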