I implemented a second-order method to solve the neural network training problem. In every iteration of this method, a quadratic approximation of the objective function is minimized (approximately) with a conjugate-gradient-like method (CG). The main cost of each CG iteration is computing the product of the Hessian (evaluated at the current iterate of the second-order method) with a vector (determined iteratively from the previous CG iteration). Since these Hessian-vector products strongly influence the performance of the algorithm, it is important to compute them as efficiently as possible.
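To illustrate the setup, here is a minimal numpy sketch of a truncated CG solver that touches the Hessian only through a Hessian-vector callback. The name `cg_solve`, the relative-residual stopping rule and the negative-curvature fallback are assumptions for illustration; the post does not say which CG variant is used.

```python
import numpy as np

def cg_solve(H_v, g, max_iter=50, tol=1e-8):
    """Approximately minimize q(p) = g.p + 0.5 p.H p, i.e. solve H p = -g,
    using only Hessian-vector products H_v(d). Sketch, not a reference implementation."""
    p = np.zeros_like(g)
    r = -g.copy()                 # residual -g - H p at p = 0
    d = r.copy()                  # first search direction: steepest descent
    rr = r @ r
    g_norm = np.sqrt(rr)
    for i in range(max_iter):
        if np.sqrt(rr) <= tol * g_norm:   # relative residual small enough
            break
        Hd = H_v(d)               # the one expensive call per CG iteration
        curv = d @ Hd
        if curv <= 0:             # negative curvature: H not positive definite along d
            return -g if i == 0 else p
        alpha = rr / curv
        p = p + alpha * d
        r = r - alpha * Hd
        rr_new = r @ r
        d = r + (rr_new / rr) * d  # new H-conjugate direction
        rr = rr_new
    return p
```

Here `H_v` would be the network's Hessian-vector function and `g` the flattened gradient as a numpy array. Each loop iteration makes exactly one `H_v` call, which is why its speed dominates the cost of the inner solve.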
I implemented the following:
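```python
import torch

output = net(input)
loss_value = loss(output, target)
# create_graph=True keeps the graph of the gradient so it can be differentiated a second time
grad_loss = torch.autograd.grad(loss_value, net.parameters(), create_graph=True)
# flatten all parameter gradients into one vector (double, to match the numpy vectors used by CG)
gradient = torch.cat([g.reshape(-1) for g in grad_loss]).double()

def H_v(s):
    v = torch.from_numpy(s)            # s: float64 numpy array, treated as a constant
    z = gradient @ v                   # scalar gradient^T v
    net.zero_grad()
    z.backward(retain_graph=True)      # each p.grad now holds its block of H v
    # return the Hessian-vector product as a flat numpy array for the CG method
    return torch.cat([p.grad.reshape(-1) for p in net.parameters()]).double().numpy()
```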
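Why this yields a Hessian-vector product: writing $L(\theta)$ for the loss as a function of the weights $\theta$, the vector $v$ does not depend on $\theta$, so differentiating the scalar `z` $= \nabla_\theta L(\theta)^\top v$ once more gives

$$
\nabla_\theta\big(\nabla_\theta L(\theta)^\top v\big) = \nabla_\theta^2 L(\theta)\, v = H v .
$$

Each call of `H_v` therefore performs one backward pass through the graph of the gradient (kept alive by `create_graph=True` and `retain_graph=True`) and never forms $H$ explicitly.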
The function H_v is passed to the CG method, where it is used to compute the required Hessian-vector products (note that my algorithm works with numpy arrays, not with tensors). All of this works and is reasonably efficient. However, for comparison I also implemented everything without PyTorch (the simple two-layer network, its output, the loss function value, the gradient and the Hessian-vector function H_v), and I made the following observation: the PyTorch version is much faster at computing the output of the network and the loss function value, and it is also faster at computing the gradient. My fairly straightforward Hessian-vector function, on the other hand, is faster than the PyTorch version shown above, in particular for small batch sizes (e.g. when the tensor “input” contains fewer than 1000 images of the MNIST dataset).
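The hand-written version is not shown in the post. Purely as an illustration of what a hand-coded Hessian-vector product for a two-layer network can look like, here is a numpy sketch based on Pearlmutter's R-operator (a forward pass of directional derivatives followed by a backward pass). The architecture is a hypothetical assumption: a tanh hidden layer, a linear output layer and a mean squared-error loss; the helper names `unpack` and `grad_and_hv` are made up for this sketch.

```python
import numpy as np

def unpack(theta, d, h, k):
    """Split a flat parameter vector into (W1, b1, W2, b2). Hypothetical layout."""
    i = 0
    W1 = theta[i:i + d * h].reshape(d, h); i += d * h
    b1 = theta[i:i + h]; i += h
    W2 = theta[i:i + h * k].reshape(h, k); i += h * k
    b2 = theta[i:i + k]
    return W1, b1, W2, b2

def grad_and_hv(theta, s, x, t, h):
    """Gradient g and Hessian-vector product H s of
    L = 0.5/n * sum((tanh(x W1 + b1) W2 + b2 - t)**2)  (assumed loss)."""
    n, d = x.shape
    k = t.shape[1]
    W1, b1, W2, b2 = unpack(theta, d, h, k)
    V1, vb1, V2, vb2 = unpack(s, d, h, k)   # direction s, split like theta
    # forward pass
    z = np.tanh(x @ W1 + b1)
    y = z @ W2 + b2
    # backward pass: ordinary gradient
    r = (y - t) / n
    dz = r @ W2.T
    da = dz * (1 - z**2)
    g = np.concatenate([(x.T @ da).ravel(), da.sum(0), (z.T @ r).ravel(), r.sum(0)])
    # R-forward pass: directional derivatives of the activations along s
    Ra = x @ V1 + vb1
    Rz = (1 - z**2) * Ra
    Ry = Rz @ W2 + z @ V2 + vb2
    # R-backward pass: directional derivative of the gradient along s, i.e. H s
    Rr = Ry / n
    Rdz = Rr @ W2.T + r @ V2.T
    Rda = Rdz * (1 - z**2) - 2 * dz * z * Rz
    Hs = np.concatenate([(x.T @ Rda).ravel(), Rda.sum(0),
                         (Rz.T @ r + z.T @ Rr).ravel(), Rr.sum(0)])
    return g, Hs
```

With a flat numpy weight vector `theta`, `lambda s: grad_and_hv(theta, s, x, t, h)[1]` would play the role of H_v; a real implementation would cache the forward and backward quantities across CG iterations instead of recomputing them on every call.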
Since I am new to PyTorch and do not have a deep understanding of automatic differentiation, I wonder whether there is a better way to define the function H_v so that the Hessian-vector products are computed as fast as possible.