[Adding functionality] Hessian and Fisher Information vector products

Hi,

I implemented Hessian and Fisher Information matrix (FIM) vector products and was wondering if there'd be interest in adding this functionality. The FIM products are optimized, in the sense that they analytically compute the Hessian of the log-loss with respect to the output layer for a set of commonly used losses.
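Concretely, the product being computed has the form

F v = J^\top H J v, \qquad J = \frac{\partial f_\theta(x)}{\partial \theta}, \qquad H = \frac{\partial^2 \ell}{\partial f \, \partial f^\top},

where J is the Jacobian of the network output with respect to the parameters and H is the Hessian of the loss with respect to the output; for softmax outputs with a negative log-likelihood loss, H = \operatorname{diag}(p) - p p^\top with p = \operatorname{softmax}(f_\theta(x)) is available in closed form, which is what the code exploits.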

Best,
P.
Edit: Tagging you in
@smth, @apaszke

Editing for posterity since no one wants to discuss how to submit this PR.

import torch

def _check_param_device(param, old_param_device):
    r"""This helper function is to check if the parameters are located
    in the same device. Currently, the conversion between model parameters
    and single vector form is not supported for multiple allocations,
    e.g. parameters in different GPUs, or mixture of CPU/GPU.

    Arguments:
        param ([Tensor]): a Tensor of a parameter of a model
        old_param_device (int): the device where the first parameter of a
                                model is allocated.

    Returns:
        old_param_device (int): report device for the first time
    """

    # Meet the first parameter
    if old_param_device is None:
        old_param_device = param.get_device() if param.is_cuda else -1
    else:
        warn = False
        if param.is_cuda:  # Check if in same GPU
            warn = (param.get_device() != old_param_device)
        else:  # Check if in CPU
            warn = (old_param_device != -1)
        if warn:
            raise TypeError('Found two parameters on different devices, '
                            'this is currently not supported.')
    return old_param_device


def vector_to_parameter_list(vec, parameters):
    r"""Convert one vector to the parameter list

    Arguments:
        vec (Tensor): a single vector represents the parameters of a model.
        parameters (Iterable[Tensor]): an iterator of Tensors that are the
            parameters of a model.
    """
    # Ensure vec of type Tensor
    if not isinstance(vec, torch.Tensor):
        raise TypeError('expected torch.Tensor, but got: {}'
                        .format(torch.typename(vec)))
    # Flag for the device where the parameter is located
    param_device = None
    params_new = []
    # Pointer for slicing the vector for each parameter
    pointer = 0
    for param in parameters:
        # Ensure the parameters are located in the same device
        param_device = _check_param_device(param, param_device)

        # The length of the parameter
        num_param = param.numel()
        # Slice the vector and reshape the slice to the parameter's shape
        param_new = vec[pointer:pointer + num_param].view_as(param).data
        params_new.append(param_new)
        # Increment the pointer
        pointer += num_param

    return params_new


def Rop(ys, xs, vs):
    """R-operator: Jacobian-vector product (dys/dxs) @ vs via double backward."""
    if isinstance(ys, tuple):
        ws = [torch.zeros_like(y, requires_grad=True) for y in ys]
    else:
        ws = torch.zeros_like(ys, requires_grad=True)

    # Backward pass w.r.t. a dummy variable ws, then differentiate the result
    # w.r.t. ws to obtain the forward-mode product J @ vs.
    gs = torch.autograd.grad(ys, xs, grad_outputs=ws, create_graph=True,
                             retain_graph=True, allow_unused=True)
    re = torch.autograd.grad(gs, ws, grad_outputs=vs, create_graph=True,
                             retain_graph=True, allow_unused=True)
    return tuple([j.detach() for j in re])


def Lop(ys, xs, ws):
    """L-operator: vector-Jacobian product ws @ (dys/dxs), i.e. a plain backward pass."""
    vJ = torch.autograd.grad(ys, xs, grad_outputs=ws, create_graph=True,
                             retain_graph=True, allow_unused=True)
    return tuple([j.detach() for j in vJ])


def HessianVectorProduct(f, x, v):
    """Hessian of the scalar f w.r.t. x, multiplied by the vector v."""
    df_dx = torch.autograd.grad(f, x, create_graph=True, retain_graph=True)
    Hv = Rop(df_dx, x, v)
    return tuple([j.detach() for j in Hv])


def FisherVectorProduct(loss, output, model, vp):
    """Fisher information matrix of the model, multiplied by the vector vp."""
    Jv = Rop(output, list(model.parameters()), vp)
    batch, dims = output.size(0), output.size(1)
    if loss.grad_fn.__class__.__name__.startswith('NllLossBackward'):
        # Softmax + NLL: the Hessian of the loss w.r.t. the output is known
        # analytically, H = diag(p) - p p^T with p = softmax(output).
        outputsoftmax = torch.nn.functional.softmax(output, dim=1)
        M = torch.zeros(batch, dims, dims, device=output.device)
        M.reshape(batch, -1)[:, ::dims + 1] = outputsoftmax  # batched diag(p)
        H = M - torch.einsum('bi,bj->bij', (outputsoftmax, outputsoftmax))
        HJv = [torch.squeeze(H @ torch.unsqueeze(Jv[0], -1)) / batch]
    else:
        # Fall back to an exact Hessian-vector product w.r.t. the output.
        HJv = HessianVectorProduct(loss, output, Jv)
    JHJv = Lop(output, list(model.parameters()), HJv)

    return torch.cat([torch.flatten(v) for v in JHJv])
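
For completeness, a minimal usage sketch of the helpers above (the toy model, data, and loss are placeholders picked for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative setup: a small classifier with random data and an NLL loss, so
# that FisherVectorProduct can take its analytic softmax branch.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.Tanh(), nn.Linear(32, 5))
x = torch.randn(16, 10)
y = torch.randint(0, 5, (16,))

output = model(x)                                  # logits, shape (16, 5)
loss = F.nll_loss(F.log_softmax(output, dim=1), y)

# Direction vector with the same shapes as the parameters.
params = list(model.parameters())
v = [torch.randn_like(p) for p in params]

# Hessian-vector product of the loss w.r.t. the parameters.
Hv = HessianVectorProduct(loss, params, v)

# Fisher-vector product, returned as a single flattened vector.
Fv = FisherVectorProduct(loss, output, model, v)
print(Fv.shape)  # one entry per model parameter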



Cool, thank you for sharing!

Best regards

Thomas


Hey,

Thanks for posting the code! Sorry for the inactivity, but we'd be very happy to add those helpers (Rop, Lop, Hv) to autograd. There are some comments that would need to be addressed (in particular, why do you keep detaching things?), but the code looks OK from a quick pass! Can you please send a PR?

Thanks,
Adam

I can send a PR; I just wanted to discuss how to structure the code first.

I keep detaching because, for my specific needs, I didn't need to keep track of these variables in autograd and wanted to explicitly free up the memory (or at least that was the intention).

The code generally looks good, and we'd want to merge Lop, Rop and Hv. The first two utilities (_check_param_device and vector_to_parameter_list) don't seem to fit autograd very well, and I can't find any resources about the Fisher product (it also looks a bit hacky due to the conditional on the loss function). You'd need to stop detaching too, since that is application-specific.

Hi @PiotrSokol:

I am wondering:
for the R-op, did you follow a similar trick to the one described here, or some other trick?

Thank you.

The Rop above looks strikingly similar to the alternative_rop function in the blog post.

It's based on the forward-on-backward trick; they're the same.
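
For anyone curious, the identity behind the trick (rough notation): introduce a dummy variable w with the same shape as y; then

J v = \frac{\partial}{\partial w} \Big( v^\top \frac{\partial (w^\top y)}{\partial x} \Big), \qquad J = \frac{\partial y}{\partial x}.

The inner term is the ordinary vector-Jacobian product J^\top w that reverse mode provides; it is linear in w, so differentiating it against w recovers the forward-mode product J v, and the actual value of w never enters the result (hence the zeros_like initialization in Rop above).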

Hello @PiotrSokol.

Did you work on any direct implementation of the FIM?

@apaszke @smth, can you please let us know if there is an update on the implementation?

Thanks.

@Sai_Gayatri_Prasad what is your application?

For a network with ReLU activations, the FIM is the same as the Hessian and the Gauss-Newton matrix. For general activations, it's the same as the Gauss-Newton matrix. See Appendix C.2 in https://arxiv.org/pdf/1706.03662.pdf

The GN matrix is much cheaper to compute than the generic Hessian. For instance, getting the diagonal of the GN has a similar cost to getting the gradient, while the diagonal of the general Hessian is likely intractable (see Section 6 of http://www.cs.utoronto.ca/~ilya/pubs/2012/CurvProp.pdf)

For a neural network with k outputs, you can compute the GN matrix with k calls to .backward(). The main tricky thing is figuring out what to feed into .backward(), which depends on the loss function used.
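
To make that concrete, here is a rough sketch (my own, under simplifying assumptions: a single example, an MSE loss so the output Hessian is the identity, and torch.autograd.grad used in place of .backward() for convenience):

import torch
import torch.nn as nn

# Sketch: Gauss-Newton matrix G = J^T H J for ONE example, using k reverse-mode
# calls (one per output dimension) to assemble the output-parameter Jacobian J.
torch.manual_seed(0)
k = 3
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, k))
params = list(model.parameters())
n_params = sum(p.numel() for p in params)

x = torch.randn(1, 4)
out = model(x).squeeze(0)          # shape (k,)

# One backward pass per output component gives one row of J.
rows = []
for i in range(k):
    grads = torch.autograd.grad(out[i], params, retain_graph=True)
    rows.append(torch.cat([g.reshape(-1) for g in grads]))
J = torch.stack(rows)              # shape (k, n_params)

# MSE loss: Hessian w.r.t. the output is the identity. For softmax
# cross-entropy it would be diag(p) - p p^T instead.
H = torch.eye(k)
G = J.t() @ H @ J                  # shape (n_params, n_params)
print(G.shape)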

The full GN matrix will need terabytes of space even for tiny MNIST networks, so in practice you'll need to use approximations like the diagonal, KFAC, etc. An example of doing this for cross-entropy loss is here: https://github.com/cybertronai/autograd-lib


Hello,

Sorry about the delayed response and thank you for the insights.

My application is the Fisher Information Matrix calculation from https://arxiv.org/abs/1711.08856, which I am trying to use in one of my projects.

@Sai_Gayatri_Prasad that paper just looks at the trace of the FIM, which is much easier to compute. For the empirical FIM, it's the average squared gradient norm. You can compute a batch of per-example gradient norms using the snippet below, and then average them.

# Assumes autograd_lib from https://github.com/cybertronai/autograd-lib, with
# create_toy_model() and the model registration done as in the library's examples.
import torch
from autograd_lib import autograd_lib

A, model = create_toy_model()
activations = {}

def save_activations(layer, a, _):
    # Stash each layer's input activations during the forward pass.
    activations[layer] = a

with autograd_lib.module_hook(save_activations):
    Y = model(A.t())
    loss = torch.sum(Y * Y) / 2

norms = {}

def compute_norms(layer, _, b):
    # Per-example squared gradient norm for a linear layer:
    # ||a||^2 * ||b||^2, with a the inputs and b the backpropagated errors.
    a = activations[layer]
    norms[layer] = (a * a).sum(dim=1) * (b * b).sum(dim=1)

with autograd_lib.module_hook(compute_norms):
    loss.backward()

For an MNIST MLP, the average gradient norm goes down over time; you won't see the "inverted V" behavior from the paper.

@Yaroslav_Bulatov Thank you for the insight. I did implement an empirical FIM (average squared gradient norm).
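
For anyone else landing here, a minimal plain-PyTorch sketch of that computation (naive per-example loop; the model and data are placeholders):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Empirical FIM trace = mean squared per-example gradient norm.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 5))
x = torch.randn(64, 10)
y = torch.randint(0, 5, (64,))

sq_norms = []
for xi, yi in zip(x, y):
    # Per-example loss and gradient (slow but simple; autograd-lib or BackPACK
    # can compute these norms in a single batched backward pass).
    loss = F.cross_entropy(model(xi.unsqueeze(0)), yi.unsqueeze(0))
    grads = torch.autograd.grad(loss, list(model.parameters()))
    sq_norms.append(sum((g * g).sum() for g in grads))

fim_trace = torch.stack(sq_norms).mean()
print(fim_trace.item())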

For future reference, there's a PyTorch package called BackPACK which could be useful: https://github.com/f-dangel/backpack (a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient).