Hessian calculation on specific components

Hello,

I need to calculate the gradient of each layer of my model. Then, for each gradient, I need to select k of its components and calculate the Hessian restricted to those k components.

Example:
If my gradient has d components, I select k of them (k < d) and calculate the Hessian (sub-Hessian) on these k selected components. This means that the sub-Hessian does not have shape (d, d) but (k, k).

Here’s a small example of how I currently proceed:

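import torch
import torch.nn as nn

def compute_sub_hessian(grad, weights, idxs):
    # grad must come from a backward pass with create_graph=True,
    # otherwise it cannot be differentiated a second time.
    n_elems = grad.numel()
    flat_grad = grad.reshape(-1)
    sub_hessian = torch.zeros(len(idxs), n_elems, dtype=grad.dtype, device=grad.device)
    for i, idx in enumerate(idxs):
        # One Hessian row: derivative of a single gradient component w.r.t. all d weights
        row_hessian = torch.autograd.grad(flat_grad[idx], weights, retain_graph=True)[0]
        row_hessian = row_hessian.reshape(-1)
        sub_hessian[i] = row_hessian
    # Keep only the k selected columns, giving a (k, k) matrix
    sub_hessian = sub_hessian[:, idxs]
    return sub_hessian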

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28*28, 10, bias=True)

    def forward(self, x):
        x = x.reshape(-1, 28*28)
        x = self.fc(x)
        return x

model = MnistModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

X = torch.rand(100, 1, 28, 28)  # 100 images of size 28*28
y = torch.randint(0, 10, size=(100,))

optimizer.zero_grad()
outputs = model(X)
loss = criterion(outputs, y)
loss.backward(create_graph=True)

for name, param in model.named_parameters():
    if param.grad is not None:
        idxs = [0, 4]  # not fixed in reality, but that doesn't matter here
        sub_hessian = compute_sub_hessian(param.grad, param, idxs)
        print(sub_hessian) 

It works, but it is very inefficient. Specifically, the line:

row_hessian = torch.autograd.grad(flat_grad[idx], weights, retain_graph=True)[0]

computes “d” derivatives, one for each element of the weight tensor, instead of only the “k” that I need.
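
To illustrate the point on a toy case (the 5-element tensor and the cubic loss below are made up purely for illustration and are not part of my model), differentiating a single gradient component returns a full row of length d:

```python
import torch

torch.manual_seed(0)
w = torch.randn(5, requires_grad=True)   # d = 5 parameters (toy example)
loss = (w ** 3).sum()                     # toy loss; its Hessian is diag(6 * w)

# First derivative, keeping the graph so it can be differentiated again
(grad,) = torch.autograd.grad(loss, w, create_graph=True)

# Differentiating ONE gradient component still yields d values:
# the whole Hessian row, not just the k entries of interest.
(row,) = torch.autograd.grad(grad[2], w, retain_graph=True)
print(row.shape)
```

So for each of the k selected indices, autograd produces all d entries of the corresponding Hessian row.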

Does anyone see a better way of doing this?
Thanks for your help.

Hi Leonard!

There’s really no way to avoid computing all “d” derivatives. PyTorch is
designed to work with entire tensors, so it naturally computes the gradient
with respect to the full weights tensor.

If you were to subvert this system somehow so that you only actually
computed “k” derivatives and only actually performed the smaller number
of floating-point operations required to compute the subset of derivatives,
your code would almost certainly run more slowly, even while performing
fewer floating-point operations. This is because streaming the full tensor
computations through the floating-point pipeline and ignoring the unused
computations turns out to be much faster than interrupting those streams
in order to perform just the required computations.

However, you can probably speed things up by pushing your Python for
loop in compute_sub_hessian() down into PyTorch’s infrastructure by
using “batched” gradients with something along the lines of:

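selectors = torch.nn.functional.one_hot(torch.tensor(idxs), num_classes=flat_grad.numel()).float()  # shape (k, d)
sub_hessian = torch.autograd.grad(flat_grad, weights, grad_outputs=selectors, retain_graph=True, is_grads_batched=True)[0]  # shape (k, *weights.shape)
sub_hessian = sub_hessian.reshape(len(idxs), -1)[:, idxs]  # shape (k, k)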

Note that both this version and your for-loop version compute k * d
individual derivatives and discard the ones that aren’t desired. This
version just (hopefully) speeds things up by using the is_grads_batched
feature (which uses vmap under the hood).
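
As a sanity check, here is a minimal sketch that compares the two approaches on a small toy problem. The tiny linear layer, the tanh / MSE loss, and the helper names sub_hessian_loop and sub_hessian_batched are assumptions chosen for illustration only; they are not taken from your model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def sub_hessian_loop(flat_grad, weights, idxs):
    # Reference version: one autograd.grad call per selected index.
    rows = []
    for idx in idxs:
        (row,) = torch.autograd.grad(flat_grad[idx], weights, retain_graph=True)
        rows.append(row.reshape(-1))
    return torch.stack(rows)[:, idxs]

def sub_hessian_batched(flat_grad, weights, idxs):
    # One-hot "selector" vectors, one per selected gradient component: shape (k, d).
    idx_t = torch.tensor(idxs, device=flat_grad.device)
    selectors = F.one_hot(idx_t, num_classes=flat_grad.numel()).to(flat_grad.dtype)
    # A single batched call computes all k Hessian rows: shape (k, *weights.shape).
    (rows,) = torch.autograd.grad(
        flat_grad, weights, grad_outputs=selectors,
        retain_graph=True, is_grads_batched=True,
    )
    # Flatten each row and keep only the k selected columns.
    return rows.reshape(len(idxs), -1)[:, idxs]

torch.manual_seed(0)
model = nn.Linear(6, 3)                      # toy layer, d = 18 weights
X = torch.rand(20, 6)
target = torch.rand(20, 3)
loss = F.mse_loss(torch.tanh(model(X)), target)

weights = model.weight
(grad,) = torch.autograd.grad(loss, weights, create_graph=True)
flat_grad = grad.reshape(-1)

idxs = [0, 4, 7]
h_loop = sub_hessian_loop(flat_grad, weights, idxs)
h_batched = sub_hessian_batched(flat_grad, weights, idxs)
print(h_loop.shape, torch.allclose(h_loop, h_batched, atol=1e-5))
```

Whether the batched version is actually faster will depend on k, d, and your hardware, so it is worth timing both on your real model.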

Best.

K. Frank

Hi KFrank,

Thanks for the answer and also for the tip to speed things up. This will be particularly useful when k is large.

Best,
Léonard