Hutchinson’s estimator for the trace of the Fisher Information Matrix

Sorry for any inconvenience; this is my first post.

I am trying to implement an FIM trace estimator from the paper


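For reference, these are the standard identities involved, stated under the usual regularity conditions. The notation ($A$, $z$, $m$, $d$, $L$, $H$, $F$, $p_w$) is chosen here for illustration and is not taken from the paper. The first identity holds for any probe distribution with $\mathbb{E}[z z^\top] = I$; Rademacher probes ($z_i \in \{-1,+1\}^d$) and standard normal probes both qualify, while uniform $[0,1)$ probes do not.

$$
\operatorname{tr}(A) = \mathbb{E}_z\left[z^\top A z\right] \ \text{ whenever } \ \mathbb{E}\left[z z^\top\right] = I,
\qquad
\widehat{\operatorname{tr}}(A) = \frac{1}{m} \sum_{i=1}^{m} z_i^\top A z_i
$$

$$
H z = \nabla_w \left( \nabla_w L(w)^\top z \right), \qquad H = \nabla_w^2 L(w)
$$

$$
\operatorname{tr}(F) = \mathbb{E}_{x}\, \mathbb{E}_{y \sim p_w(y \mid x)} \left[ \left\lVert \nabla_w \log p_w(y \mid x) \right\rVert^2 \right]
= \mathbb{E}_{x}\, \mathbb{E}_{y \sim p_w(y \mid x)} \left[ -\operatorname{tr}\, \nabla_w^2 \log p_w(y \mid x) \right]
$$

The first two lines are what the function below computes layer by layer, so it estimates the trace of the Hessian of `loss`. By the last line, that matches the Fisher trace when `loss` is the negative log-likelihood with `y` sampled from the model's own predictive distribution; with labels taken from the dataset, the Hessian generally differs from the Fisher.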
This is the function I came up with while trying to implement it:

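import torch

def fim_trace(loss, model, m):
    fim = {}
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            if module.weight.requires_grad:
                fim[name] = 0.
                # Gradient of the loss w.r.t. this layer's weights; create_graph=True
                # keeps the graph so the gradient can be differentiated a second time.
                dloss_w = torch.autograd.grad(loss, module.weight, create_graph=True)[0].flatten()
                for _ in range(m):
                    # Rademacher probe (entries +1/-1), which satisfies E[z z^T] = I
                    # as Hutchinson's estimator requires.
                    z = torch.randint_like(dloss_w, 2) * 2 - 1
                    # Hessian-vector product: H z = d(g . z)/dw, with z held constant.
                    loss2_w = torch.dot(dloss_w, z)
                    dloss2_w = torch.autograd.grad(loss2_w, module.weight, retain_graph=True)[0]
                    # z^T H z is a single-probe unbiased estimate of this layer's tr(H).
                    loss3_2 = torch.dot(z, dloss2_w.flatten())
                    fim[name] += loss3_2.item()
                fim[name] /= m
    # Total = sum of the per-layer (diagonal-block) estimates.
    total = sum(fim.values())
    return fim, total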
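One way to check the two ingredients of the function above (the double-backward Hessian-vector product and the probe distribution) is to run them on a toy quadratic whose Hessian is known exactly. Everything in the sketch below (`A`, `w`, `hutchinson`, `rademacher`, `uniform`, the seed and the probe count) is hypothetical and made up for illustration; it is not a test of `fim_trace` itself. With uniform `torch.rand_like` probes, E[z z^T] = I/12 + 11^T/4, so the average converges to tr(A)/12 + (sum of all entries of A)/4 instead of tr(A), while ±1 probes are unbiased.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy problem: loss(w) = 0.5 * w^T A w, whose Hessian is exactly A.
n = 50
B = torch.randn(n, n, dtype=torch.float64)
A = B @ B.T  # symmetric positive semi-definite stand-in for a Hessian
w = torch.randn(n, dtype=torch.float64, requires_grad=True)
loss = 0.5 * w @ A @ w

# First-order gradient; create_graph=True keeps it differentiable.
g = torch.autograd.grad(loss, w, create_graph=True)[0]


def hutchinson(probe, m=2000):
    """Average z^T H z over m probes z = probe(g), using double backward for H z."""
    est = 0.0
    for _ in range(m):
        z = probe(g)
        # Hessian-vector product: d(g . z)/dw = H z (z is a constant here)
        hz = torch.autograd.grad(torch.dot(g, z), w, retain_graph=True)[0]
        est += torch.dot(z, hz).item()
    return est / m


def rademacher(x):
    # Entries +1/-1 with equal probability: E[z z^T] = I
    return torch.randint_like(x, 2) * 2 - 1


def uniform(x):
    # Entries uniform on [0, 1): E[z z^T] = I/12 + 11^T/4, not I
    return torch.rand_like(x)


exact = A.trace().item()
est_rad = hutchinson(rademacher)
est_uni = hutchinson(uniform)
# Limit of the uniform-probe estimate: tr(A)/12 + (sum of all entries of A)/4
expected_uni = (A.trace() / 12 + A.sum() / 4).item()

print(f"exact trace       : {exact:.1f}")
print(f"Rademacher probes : {est_rad:.1f}")
print(f"uniform probes    : {est_uni:.1f}")
print(f"uniform, predicted: {expected_uni:.1f}")
```

Because a trace only involves diagonal entries, summing per-layer estimates (the diagonal blocks of the Hessian) as `fim_trace` does gives the trace over all the selected weights; the cross-layer blocks never contribute.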

Could anyone check whether this is implemented correctly and help me out? :slight_smile: