Vmap and autograd for grad of output wrt input

Hi all,

I need to compute the second derivative of the network output with respect to the input, and I need it per sample.

Following PyTorch’s per-sample-gradients tutorial (Per-sample-gradients — PyTorch Tutorials 2.0.1+cu117 documentation), I managed to put together some simple code:


        u_pred = self.forward(inputs)

        def compute_grad(u_in, x_in):
            # first derivative du/dx of one sample's output wrt the inputs
            grad_x  = torch.autograd.grad(u_in.unsqueeze(0), x_in, grad_outputs=torch.ones_like(u_in.unsqueeze(0)), retain_graph=True, create_graph=True)[0]
            # second derivative d2u/dx2; grad_outputs must match grad_x here, not u_in
            # (only_inputs=True is deprecated and already the default, so it is dropped)
            grad_xx = torch.autograd.grad(grad_x, x_in, grad_outputs=torch.ones_like(grad_x), retain_graph=True, create_graph=True)[0]
            return grad_xx

        bs = u_pred.shape[0]  # batch size
        sample_grads = [compute_grad(u_pred[i], inputs) for i in range(bs)]
        # compute_grad returns a single tensor, so a plain stack suffices; the
        # zip(*...) shuffle in the tutorial is for tuples of per-parameter grads
        sample_grads = torch.stack(sample_grads)

It works, but it is super slow.
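One thing I am considering: since each u_pred[i] should only depend on inputs[i] in my setup (no cross-sample operations like batch norm), summing the outputs before calling autograd.grad should give the per-sample derivatives in a single backward pass, without the Python loop. A rough sketch, assuming a 1-D input per sample so that grad_x.sum() really isolates d2u/dx2:

        # valid only if sample i's output depends only on sample i's input
        grad_x  = torch.autograd.grad(u_pred.sum(), inputs, create_graph=True)[0]   # du/dx,   shape (bs, 1)
        grad_xx = torch.autograd.grad(grad_x.sum(), inputs, create_graph=True)[0]   # d2u/dx2, shape (bs, 1)

The cross-sample entries of the Jacobian are zero in that case, so the sum does not mix samples.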
I also attempted to use vmap, but it does not seem to work with autograd.
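From the docs, my understanding is that vmap only composes with the functional torch.func API (torch.func.grad / torch.func.hessian), not with torch.autograd.grad calls like the ones above. Roughly what I think it would have to look like (just a sketch, assuming self.forward maps a batch of shape (bs, d) to outputs of shape (bs, 1) or (bs,)):

        from torch.func import hessian, vmap

        def u_scalar(x):
            # x is a single sample of shape (d,); forward expects a batch dim
            return self.forward(x.unsqueeze(0)).squeeze()

        # per-sample Hessian of the scalar output wrt the input, shape (bs, d, d)
        per_sample_hess = vmap(hessian(u_scalar))(inputs)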
Does anybody have a recommendation on how to speed up these computations?