Hi,
I have a problem with my weight updates when calling optimizer.step()
after manually calculating and assigning the gradients. I checked that .grad
for every parameter is neither None
nor zero, but when I call optimizer.state_dict()
I get the following unusual output:
{'state': {0: {'momentum_buffer': None}, 1: {'momentum_buffer': None}, 2: {'momentum_buffer': None}, 3: {'momentum_buffer': None}, 4: {'momentum_buffer': None}, 5: {'momentum_buffer': None}, 6: {'momentum_buffer': None}, 7: {'momentum_buffer': None}, 8: {'momentum_buffer': None}}, 'param_groups': [{'lr': 0.005, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8]}]}
I share the code that I am using below:
from functorch import jacrev, vmap, make_functional, grad
def compute_centered_jacobian(model, samples):
    """Return flattened Jacobians of the model's two outputs w.r.t. its parameters.

    The model is converted to a functional form so that its parameters can be
    passed as explicit inputs to ``torch.autograd.functional.jacobian``.

    Args:
        model: Module handed to ``make_functional``.
        samples: Input forwarded unchanged to the functional model.

    Returns:
        A pair ``(jac_ampl, jac_phase)``. Each is the concatenation, along the
        last dimension, of the per-parameter Jacobian blocks reshaped to
        ``(block.size(0), -1)``.

    NOTE(review): the model is presumably expected to return exactly two
    outputs (amplitude and phase), since only index 0 and 1 of the Jacobian
    tuple are used -- confirm. Despite the function name, no centering
    (mean subtraction) is performed here.
    """
    functional_model, weights = make_functional(model)

    def forward(*flat_weights):
        # Expose the parameters as positional inputs so that autograd
        # differentiates with respect to them rather than to ``samples``.
        return functional_model(flat_weights, samples)

    def flatten_blocks(per_param_blocks):
        # One block per parameter tensor: keep the leading dimension, flatten
        # the rest, then glue all blocks side by side.
        columns = [
            block.reshape(block.size(0), -1) for block in per_param_blocks
        ]
        return torch.cat(columns, dim=-1)

    jacobians = torch.autograd.functional.jacobian(forward, weights)
    return flatten_blocks(jacobians[0]), flatten_blocks(jacobians[1])
def _compute_gradient_with_curvature(Tinv, E, O):
    """Compute ``O^T (Tinv E / n) / n`` with ``n = Tinv.size(0)``.

    Args:
        Tinv: 2-D tensor; its first dimension supplies the normalisation ``n``.
        E: 1-D tensor multiplied by ``Tinv``.
        O: 2-D tensor whose transpose is applied to the scaled ``Tinv @ E``.

    Returns:
        1-D tensor with one entry per column of ``O``.
    """
    sample_count = Tinv.size(0)
    weighted_energy = torch.mv(Tinv, E) / sample_count
    return torch.mv(O.t(), weighted_energy) / sample_count
def compute_gradient_with_curvature(Ore, Oim, E, model, **kwargs):
    """Combine two Jacobians and ``E`` into a single update vector.

    Builds ``T = Ore Ore^T + Oim Oim^T``, pseudo-inverts it, and sums the
    contributions of ``E.real`` (paired with ``Ore``) and ``E.imag`` (paired
    with ``Oim``).

    Args:
        Ore: 2-D tensor paired with the real part of ``E``.
        Oim: 2-D tensor paired with the imaginary part of ``E``.
        E: Tensor accessed through ``.real`` and ``.imag`` (presumably
            complex-valued -- confirm against the caller).
        model: Unused here; kept for interface compatibility.
        **kwargs: Ignored.

    Returns:
        1-D tensor holding the summed update.
    """
    gram = torch.mm(Ore, Ore.t()) + torch.mm(Oim, Oim.t())
    # A very small rtol keeps almost all singular values when pseudo-inverting.
    gram_pinv = torch.linalg.pinv(gram, rtol=1e-12)
    real_term = _compute_gradient_with_curvature(gram_pinv, E.real, Ore)
    imag_term = _compute_gradient_with_curvature(gram_pinv, E.imag, Oim)
    return real_term + imag_term
def apply_grads(model, grad):
    """Scatter a flat gradient vector into the ``.grad`` of each parameter.

    Trainable parameters are visited in ``model.parameters()`` order and each
    one consumes ``p.numel()`` consecutive entries of ``grad``.

    Args:
        model: Module whose trainable parameters receive the gradients.
        grad: 1-D tensor whose length equals the total number of elements of
            all parameters with ``requires_grad=True``.

    Raises:
        ValueError: If ``grad`` does not hold exactly one entry per trainable
            parameter element.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    expected = sum(p.numel() for p in params)
    if grad.numel() != expected:
        raise ValueError(
            f"grad has {grad.numel()} elements, but the model has "
            f"{expected} trainable parameter elements"
        )

    offset = 0
    for p in params:
        n = p.numel()
        chunk = grad[offset : offset + n].view(p.size())
        if p.grad is not None:
            p.grad.copy_(chunk)
        else:
            print("gradient = None. Please check whats going wrong!")
            # Clone so the parameter owns its gradient buffer instead of
            # aliasing a slice of the caller's flat vector.
            p.grad = chunk.detach().clone()
        # Advance by the number of elements consumed, not by one; otherwise
        # every parameter after the first reads a misaligned slice.
        offset += n
def run_sr(model, E, samples, optimizer, scheduler=None):
    """Run one update step using manually computed gradients.

    Computes the two Jacobians of ``model`` on ``samples``, turns them and
    ``E`` into a flat gradient vector, prints it, writes it into the
    parameters' ``.grad`` fields and calls ``optimizer.step()``.

    Args:
        model: Module being optimised.
        E: Tensor forwarded to ``compute_gradient_with_curvature``.
        samples: Input forwarded to ``compute_centered_jacobian``.
        optimizer: Optimizer whose ``step()`` applies the assigned gradients.
        scheduler: Accepted for interface compatibility; not used here.
    """
    Ore, Oim = compute_centered_jacobian(model, samples)
    # Argument order must match the signature (Ore, Oim, ...): the callee
    # pairs E.real with its first argument and E.imag with its second.
    grads = compute_gradient_with_curvature(Ore, Oim, E, model)
    print(grads)
    apply_grads(model, grads)
    optimizer.step()
I would be happy about any help!