Hi all,
I need to compute the second derivative of the network output with respect to its input, and I need these derivatives for each sample separately.
Following PyTorch's tutorial (Per-sample-gradients — PyTorch Tutorials 2.0.1+cu117 documentation), I put together some simple code:
import torch

# inputs: tensor of shape (bs, d) created with requires_grad=True
u_pred = self.forward(inputs)

def compute_grad(u_in, x_in):
    # First derivative of one sample's output wrt the whole input batch
    grad_x = torch.autograd.grad(u_in, x_in, grad_outputs=torch.ones_like(u_in), retain_graph=True, create_graph=True)[0]
    # Differentiate again; grad_outputs must have the same shape as grad_x
    grad_xx = torch.autograd.grad(grad_x, x_in, grad_outputs=torch.ones_like(grad_x), retain_graph=True, create_graph=True)[0]
    return grad_xx

bs = inputs.shape[0]
sample_grads = [compute_grad(u_pred[i], inputs) for i in range(bs)]
# Same layout as the tutorial's zip/stack pattern: result[j, i] = sample_grads[i][j]
sample_grads = torch.stack(sample_grads, dim=1)
It works, but it is super slow.
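A likely reason for the slowdown is that every call to compute_grad backpropagates through the entire batch, so the loop runs two full backward passes per sample. If samples do not interact inside the network (an assumption: no BatchNorm or other cross-sample layers), the gradient of u_pred.sum() with respect to inputs already equals each sample's own derivative, because u_i depends only on x_i. The sketch below assumes a scalar output per sample and uses a small hypothetical MLP `net` in place of the real model; it computes the diagonal of each per-sample Hessian with one backward pass per input dimension rather than per sample.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for the real model: scalar output per sample and no
# cross-sample layers (e.g. no BatchNorm), so samples are independent.
net = torch.nn.Sequential(
    torch.nn.Linear(3, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
)

bs, d = 8, 3
inputs = torch.randn(bs, d, requires_grad=True)
u_pred = net(inputs)  # shape (bs, 1)

# One backward pass for the whole batch: d(sum_i u_i)/dx_i = du_i/dx_i,
# since u_i depends only on x_i. create_graph=True allows differentiating again.
grad_x = torch.autograd.grad(u_pred.sum(), inputs, create_graph=True)[0]  # (bs, d)

# Diagonal of each per-sample Hessian, d^2 u_i / dx_{i,k}^2: one extra backward
# pass per input dimension k (d passes in total, independent of bs).
grad_xx = torch.stack(
    [
        torch.autograd.grad(grad_x[:, k].sum(), inputs, create_graph=True)[0][:, k]
        for k in range(d)
    ],
    dim=1,
)  # (bs, d)
```

For a one-dimensional input (d = 1) this reduces to exactly two backward passes in total. If the loss is built from grad_xx, keeping create_graph=True lets loss.backward() reach the network weights.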
I also tried vmap, but it does not seem to work with torch.autograd.grad.
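torch.vmap is not meant to be combined with torch.autograd.grad; it is designed to compose with the torch.func transforms (torch.func.grad, jacrev, jacfwd, hessian) together with torch.func.functional_call. A minimal sketch, assuming PyTorch 2.x, a scalar output per sample, and the same kind of hypothetical network `net` (the helper name `u_single` is also illustrative):

```python
import torch
from torch.func import functional_call, hessian, vmap

torch.manual_seed(0)

# Hypothetical network with a scalar output per sample.
net = torch.nn.Sequential(
    torch.nn.Linear(3, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
)
params = dict(net.named_parameters())

def u_single(x):
    # Evaluate the network on ONE sample x of shape (d,) and return a scalar.
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze()

inputs = torch.randn(8, 3)

# hessian(u_single) maps x of shape (d,) to a (d, d) matrix; vmap maps it over the batch.
per_sample_hessians = vmap(hessian(u_single))(inputs)  # (bs, d, d)
second_derivs = torch.diagonal(per_sample_hessians, dim1=-2, dim2=-1)  # (bs, d)
```

This computes the full per-sample Hessian, which is wasteful for large input dimensions if only the diagonal is needed; for low-dimensional inputs it keeps the per-sample structure explicit without a Python loop over the batch.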
Does anybody have a recommendation on how to speed up these computations?