Weights don't change during optimizer.step() for custom gradient calculation

Hi,

I am having problems with my weight updates: the parameters do not change when I call optimizer.step() after manually computing and assigning the gradients. I checked that .grad is neither None nor zero for any parameter, but when I call optimizer.state_dict() I get this (to me unusual) output:

{'state': {0: {'momentum_buffer': None}, 1: {'momentum_buffer': None}, 2: {'momentum_buffer': None}, 3: {'momentum_buffer': None}, 4: {'momentum_buffer': None}, 5: {'momentum_buffer': None}, 6: {'momentum_buffer': None}, 7: {'momentum_buffer': None}, 8: {'momentum_buffer': None}}, 'param_groups': [{'lr': 0.005, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8]}]}
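
To verify that the weights really do not change, I compare the parameters before and after the step (model and optimizer are the same objects as in the code further down):

before = [p.detach().clone() for p in model.parameters()]
optimizer.step()
for old, p in zip(before, model.parameters()):
    # maximum absolute change per parameter tensor; this stays at 0.0 for every parameter
    print((p.detach() - old).abs().max().item())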

Here is the code I am using:

import torch
from functorch import jacrev, vmap, make_functional, grad

def compute_centered_jacobian(model, samples):
    # Build a functional version of the model so the Jacobian can be taken
    # with respect to the parameters.
    func, parameters = make_functional(model)
    def func_(*params):
        return func(params, samples)
    jac = torch.autograd.functional.jacobian(func_, parameters)
    # jacobian() returns one entry per model output (here: amplitude and phase),
    # each a tuple of per-parameter Jacobians.
    jac_ampl = jac[0]
    jac_phase = jac[1]
    # Flatten each per-parameter Jacobian and concatenate into a single
    # (n_samples, n_params) matrix per output.
    jac_ampl = torch.cat([it.reshape(it.size(0), -1) for it in jac_ampl], dim=-1)
    jac_phase = torch.cat([it.reshape(it.size(0), -1) for it in jac_phase], dim=-1)
    return jac_ampl, jac_phase
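
For reference, the Jacobian shapes can be sanity-checked like this (model and samples as in my training script; this assumes all parameters are trainable):

jac_ampl, jac_phase = compute_centered_jacobian(model, samples)
n_params = sum(p.numel() for p in model.parameters())
print(jac_ampl.shape, jac_phase.shape, n_params)  # expected: (n_samples, n_params) for both Jacobians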


def _compute_gradient_with_curvature(Tinv, E, O):
    # Computes O.T @ (Tinv @ E), dividing by n_samples at each of the two steps.
    n_samples = Tinv.size(0)
    TinvE = torch.mv(Tinv, E) / n_samples
    δ = torch.einsum("ij,j", O.t(), TinvE) / n_samples
    return δ


def compute_gradient_with_curvature(Ore, Oim, E, model, **kwargs):
    # T = Ore @ Ore.T + Oim @ Oim.T is an (n_samples, n_samples) matrix;
    # the gradient combines the real and imaginary contributions of E.
    T = torch.einsum("ij,jk", Ore, Ore.t()) + torch.einsum("ij,jk", Oim, Oim.t())
    Tinv = torch.linalg.pinv(T, rtol=1e-12)
    δ = _compute_gradient_with_curvature(Tinv, E.real, Ore) + _compute_gradient_with_curvature(Tinv, E.imag, Oim)
    return δ
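
The returned vector can be checked for obvious problems like this (grads is the output of compute_gradient_with_curvature; this is only a diagnostic, not part of the training code):

print(grads.shape)                          # one entry per trainable parameter
print(torch.isfinite(grads).all().item())   # should be True
print(grads.norm().item())                  # should be clearly non-zero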

def apply_grads(model, grad):
    # Scatter the flat gradient vector back into the .grad attribute of each
    # trainable parameter.
    i = 0
    for p in filter(lambda x: x.requires_grad, model.parameters()):
        n = p.numel()
        if p.grad is not None:
            p.grad.copy_(grad[i : i + n].view(p.size()))
        else:
            print("gradient = None. Please check what's going wrong!")
            p.grad = grad[i : i + n].view(p.size())
        i += n  # advance by the number of elements consumed, not by 1
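
To confirm that the values really end up in .grad, they can be compared against the flat gradient vector right after apply_grads (diagnostic only):

flat = torch.cat([p.grad.reshape(-1) for p in model.parameters() if p.requires_grad])
print(torch.allclose(flat, grads.reshape(-1).to(flat.dtype)))  # should print True if the copy worked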

def run_sr(model, E, samples, optimizer, scheduler=None):
    Ore, Oim = compute_centered_jacobian(model, samples)
    grads = compute_gradient_with_curvature(Ore, Oim, E, model)  # argument order matches the function signature
    print(grads)
    apply_grads(model,grads)
    optimizer.step()
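
For completeness, this is roughly how everything is wired together (E and samples come from the rest of my code, which I left out; the optimizer settings match the state_dict shown above):

optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
run_sr(model, E, samples, optimizer)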

I would appreciate any help!