Hessian diagonal, second derivative wrt a single weight

Hello, I’m trying to build the diagonal of the Hessian matrix of the loss with respect to the weights.

It can be done by iterating over each element of the first derivative of the loss and taking its derivative with respect to the weights again. The problem is that this is a very heavy computation (one backward pass per parameter, plus the full n × n matrix in memory), and I think it is impractical for ‘big’ networks with over 1M parameters (not really big compared to the really big ones).
The following code does this:

import torch
from torch.autograd import grad

def eval_hessian(loss, weights):
    # https://discuss.pytorch.org/t/compute-the-hessian-matrix-of-a-network/15270/3
    loss_grad = grad(loss, weights, create_graph=True, retain_graph=True)

    g_vector = torch.empty((0,))
    for g in loss_grad:
        g_vector = torch.cat([g_vector, g.contiguous().view(-1)])
    l = g_vector.size(0)
    hessian = torch.zeros(l, l)
    for idx in range(l):
        grad2rd = grad(g_vector[idx], weights, create_graph=True, retain_graph=True)
        g2 = torch.empty((0,))
        for g in grad2rd:
            g2 = torch.cat([g2, g.contiguous().view(-1)])
        hessian[idx] = g2
    return hessian

Since I only need the diagonal of the matrix, I tried a modification that takes the second derivative only with respect to the same weight as the first derivative, instead of computing the whole matrix:

def eval_hessian_diag(loss, weights):
    def cat_tensors(tensors):
        t_tensor = torch.empty((sum(t.numel() for t in tensors),))
        offset = 0
        for t in tensors:
            t_tensor[offset:offset + t.numel()] = t.contiguous().view(-1)
            offset += t.numel()
        return t_tensor
    
    loss_grad = grad(loss, weights, create_graph=True, retain_graph=True)
    g_tensor = cat_tensors(loss_grad)
    w_tensor = cat_tensors(weights)
    
    l = g_tensor.size(0)
    hessian_diag = torch.empty(l)    
    for idx in range(l):
        grad2rd = grad(g_tensor[idx], w_tensor[idx], create_graph=True)
        hessian_diag[idx] = grad2rd[0]  # index with idx; l would be out of range

    return hessian_diag

But when I do this I get the error:
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
The error originates here:
grad2rd = grad(g_tensor[idx], w_tensor[idx], create_graph=True)
I found out that the gradient has to be taken with respect to the whole weight tensor (i.e. the complete weight matrix of a NN layer) and cannot be computed with respect to a single weight.
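A minimal sketch of why this happens, assuming a toy loss `sum(w**3)` that is not from the original post: indexing a tensor creates a new tensor that the loss was never computed from, so autograd treats it as unused. The same applies to `w_tensor` above, which is a copy of the weights rather than a tensor that `g_tensor` depends on.

```python
import torch
from torch.autograd import grad

# Toy setup (an assumption for illustration): loss = sum(w**3),
# so the Hessian is diagonal with entries 6 * w_i.
w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (w ** 3).sum()

# First derivative, keeping the graph so it can be differentiated again.
(g,) = grad(loss, w, create_graph=True)

# w[0] is a new tensor produced by indexing. g[0] was not computed from it,
# so with allow_unused=True autograd returns None instead of raising.
(unused,) = grad(g[0], w[0], retain_graph=True, allow_unused=True)

# Differentiating with respect to the original leaf tensor works and
# gives row 0 of the Hessian; entry 0 of that row is the diagonal element.
(row0,) = grad(g[0], w, retain_graph=True)
```

Here `unused` is `None`, while `row0[0]` holds the second derivative with respect to the first weight.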

So my question is whether there is another way to compute the diagonal of the Hessian, i.e. the second derivatives, without having to compute them with respect to every possible pair of weights.
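One option that is not mentioned in this thread is a stochastic estimate of the diagonal from Hessian-vector products, often called Hutchinson's estimator: for random vectors v with entries ±1, the average of v ⊙ (Hv) approaches diag(H). The sketch below is only an illustration under that assumption. The name `hessian_diag_hutchinson` and the `n_samples` parameter are hypothetical, the result is an approximation unless the Hessian happens to be diagonal, and the code assumes every parameter's gradient still depends on the parameters.

```python
import torch
from torch.autograd import grad

def hessian_diag_hutchinson(loss, params, n_samples=100):
    """Estimate diag(H) as the mean of v * (H v) over random +/-1 vectors v."""
    params = list(params)
    # First derivatives, with a graph so they can be differentiated again.
    grads = grad(loss, params, create_graph=True)
    est = [torch.zeros_like(p) for p in params]
    for _ in range(n_samples):
        # Rademacher probes: each entry is +1 or -1 with equal probability.
        vs = [torch.randint_like(p, 2) * 2 - 1 for p in params]
        # Hessian-vector product: d/dw (g . v) = H v, one backward pass.
        hvs = grad(grads, params, grad_outputs=vs, retain_graph=True)
        for e, v, hv in zip(est, vs, hvs):
            e += v * hv / n_samples
    # Flatten in the same order as the parameters.
    return torch.cat([e.reshape(-1) for e in est])

# Toy check (an assumption for illustration): loss = sum(w**3) has a
# diagonal Hessian, so every sample already equals the true diagonal 6 * w.
w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
diag_est = hessian_diag_hutchinson((w ** 3).sum(), [w], n_samples=4)
```

Each sample costs one extra backward pass regardless of the number of parameters, so the trade-off is accuracy against `n_samples` rather than against network size.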


You have certainly solved this issue by now, but to help the next lost soul, the line is supposed to be:

grad2rd = grad(g_tensor[idx], w_tensor[idx], retain_graph=True)

Hello, I attempted the solution you proposed here, yet I’m still getting the same error. Can you perhaps spot why?

def eval_hessian_diag(loss, model):
	def cat_tensors(tensors):
		t_tensor = torch.empty((sum(t.numel() for t in tensors),))
		offset = 0
		for t in tensors:
			t_tensor[offset:offset + t.numel()] = t.contiguous().view(-1)
			offset += t.numel()
		return t_tensor
	weights = list(model.parameters())
	loss_grad = grad(loss, weights, create_graph=True, retain_graph=True)
	g_tensor = cat_tensors(loss_grad)
	w_tensor = cat_tensors(weights)
    
	l = g_tensor.size(0)
	hessian_diag = torch.empty(l)   
	for idx in range(l):
		grad2rd = grad(g_tensor[idx], w_tensor[idx], retain_graph=True)
		hessian_diag[idx] = grad2rd[0]  # index with idx; l would be out of range

	return hessian_diag

Just to clarify, the only difference between my function and the one above is that mine takes in the model and gets the weights by calling list(model.parameters()). I did this to make it clear where the weights come from.
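A likely reason the error persists is that `w_tensor` is a fresh copy of the parameters, so `g_tensor` has no graph dependency on it, and changing `create_graph` to `retain_graph` does not affect that. A minimal sketch of a workaround, assuming it is acceptable to differentiate with respect to the full parameter list and keep only the matching entry: the name `hessian_diag_exact` and the toy tensors are hypothetical, and the loop still needs one backward pass per parameter, so it saves memory rather than time.

```python
import torch
from torch.autograd import grad

def hessian_diag_exact(loss, params):
    """Exact Hessian diagonal, one Hessian row at a time (O(n) memory)."""
    params = list(params)
    grads = grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    n = flat_grad.numel()
    diag = torch.empty(n, dtype=flat_grad.dtype, device=flat_grad.device)
    for idx in range(n):
        # Differentiate one gradient entry w.r.t. the ORIGINAL parameters.
        row = grad(flat_grad[idx], params, retain_graph=True, allow_unused=True)
        # Parameters that do not influence this entry come back as None.
        flat_row = torch.cat([
            (r if r is not None else torch.zeros_like(p)).reshape(-1)
            for r, p in zip(row, params)
        ])
        diag[idx] = flat_row[idx]  # keep only the diagonal element
    return diag

# Toy parameters standing in for model.parameters() (an assumption).
w1 = torch.tensor([1.0, 2.0], requires_grad=True)
w2 = torch.tensor([[0.5]], requires_grad=True)
loss = (w1 ** 3).sum() + (w2 ** 2).sum() * w1[0]
diag = hessian_diag_exact(loss, [w1, w2])
```

For this toy loss the diagonal works out by hand to 6, 12 and 2, which can be used to sanity-check the function before trying it on a real model.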