Hi,
Sorry for any inconvenience; this is my first post.
I am trying to implement the FIM (Fisher Information Matrix) trace estimator from the paper https://arxiv.org/pdf/2012.14193.pdf

My attempt at an implementation produced the following function:
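import torch

def fim_trace(loss, model, m):
    # Per layer: average of z^T H z over m random probes z (Hutchinson-style),
    # where H is the Hessian of the loss w.r.t. that layer's weights.
    fim = {}
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            if module.weight.requires_grad:
                fim[name] = 0.
                # First-order gradient; create_graph=True keeps it differentiable.
                dloss_w = torch.autograd.grad(loss, module.weight, create_graph=True)[0].flatten()
                for _ in range(m):
                    # Zero-mean, unit-variance probe. torch.rand_like (uniform on [0, 1))
                    # is not zero-mean, so it would bias z^T H z as a trace estimate.
                    z = torch.randn_like(dloss_w)
                    loss2_w = torch.dot(dloss_w, z)
                    # Gradient of (grad . z) is the Hessian-vector product H z.
                    dloss2_w = torch.autograd.grad(loss2_w, module.weight, retain_graph=True)[0]
                    loss3_2 = torch.dot(z, dloss2_w.flatten())
                    fim[name] += loss3_2.item()
                fim[name] /= m
    # Sum the per-layer estimates (renamed so it no longer shadows the function name).
    total = 0.
    for name in fim:
        total += fim[name]
    return fim, total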
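For comparison, here is a minimal sketch of the FIM trace under its standard definition: the expected squared norm of the gradient of the log-likelihood, with labels sampled from the model's own predictive distribution rather than taken from the dataset. The names `fim_trace_sampled` and `n_label_samples` are hypothetical, and the sketch assumes a classifier that returns logits and uses every trainable parameter in its forward pass. It is not taken from the paper's code.

```python
import torch
import torch.nn.functional as F


def fim_trace_sampled(model, inputs, n_label_samples=1):
    """Monte Carlo estimate of Tr(F) = E_x E_{y ~ p(y|x)} ||grad log p(y|x)||^2.

    Assumption: `model` maps a batch of inputs to classification logits.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    total = 0.0
    for x in inputs:  # one example at a time, so each gradient is per-example
        logits = model(x.unsqueeze(0))
        probs = F.softmax(logits.detach(), dim=-1)
        for _ in range(n_label_samples):
            # Sample the label from the model's own prediction.
            y = torch.multinomial(probs, num_samples=1).squeeze(1)
            nll = F.cross_entropy(logits, y)
            grads = torch.autograd.grad(nll, params, retain_graph=True)
            # Squared gradient norm, summed over all parameter tensors.
            total += sum(g.pow(2).sum().item() for g in grads)
    return total / (len(inputs) * n_label_samples)
```

A call might look like `fim_trace_sampled(model, x_batch, n_label_samples=1)`, where `x_batch` is a tensor of inputs. Unlike the function above, this needs only first-order gradients.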
Could anyone check whether the fim_trace function above is implemented correctly and help me fix it if not?
Greetings