Creating and accessing a custom state in an optimizer

I am calculating the norm of the gradients at the current step during optimization and would like to access it in the next step. I have tried to access it through the optimizer's state_dict(). Below is my code:

for e in range(epochs):
    correct, total, epoch_loss = 0, 0, 0.0
    for batch_idx, (inputs, labels) in enumerate(loader):
        # Initialize prev_grad_norm as a zero tensor inside the optimizer's state_dict
        if 'prev_grad_norm' not in optimizer.state_dict():
            prev_grad_norm = torch.tensor(0.0)
            optimizer_state = optimizer.state_dict()
            optimizer_state['prev_grad_norm'] = prev_grad_norm
            print("1. Norm in previous step:", prev_grad_norm)
        prev_grad_norm = optimizer.state_dict().get('prev_grad_norm')
        print("Norm in previous step:", prev_grad_norm)

        optimizer.zero_grad()
        data = inputs.view(inputs.shape[0], -1)
        outputs = net(data)
        loss = criterion(outputs, labels)
        loss.backward()  # gradients must exist before they can be clipped
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
        # Add Gaussian noise to the clipped gradients
        for param in net.parameters():
            if param.grad is not None:
                param.grad += torch.randn_like(param.grad) * noise_scale
        current_norm = torch.max(torch.stack([p.grad.norm(2) for p in net.parameters() if p.grad is not None]))
        print("2. Norm in current step:", current_norm)
        optimizer.step()

        # Metrics
        epoch_loss += loss.item()
        total += labels.size(0)
        correct += (torch.max(outputs, 1)[1] == labels).sum().item()
        print("3. Norm updated in current step:", current_norm)
        optimizer.state_dict()['prev_grad_norm'] = current_norm

Although I can see that the norm is calculated at the current step, when I try to access it through 'prev_grad_norm' in the next step, it shows up as None. I want the value from the previous step so that I can use it as the clipping threshold in the next step.

Why is the state not being carried over to the next step?

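I don't think you can store arbitrary data in the optimizer's state_dict, as it should only return the state and param_groups keys as seen here. In addition, optimizer.state_dict() builds and returns a new dict on every call, so assigning a key to the returned dict does not change the optimizer itself. You might therefore need to store the value in a custom state class instead.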
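A quick way to see why the value disappears, assuming a recent PyTorch release: the snippet below uses a toy `Linear` model and an `SGD` optimizer (chosen only for illustration), writes a key into the dict returned by `state_dict()`, and then inspects a second call.

```python
import torch

# A tiny model and optimizer, just to have something to inspect
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

sd = optimizer.state_dict()
print(sorted(sd.keys()))  # ['param_groups', 'state']

# Writing into the returned dict only modifies this particular dict object...
sd["prev_grad_norm"] = torch.tensor(1.0)

# ...and the next call builds a fresh dict, so the key is not there
fresh = optimizer.state_dict()
print("prev_grad_norm" in fresh)  # False
print(fresh is sd)  # False
```

This is the same thing that happens to the last line of the loop above: the assignment goes into a temporary dict that is discarded right away.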

Hi @ptrblck, could you please share any documentation on implementing a custom state class? I am fine with accessing the previous norm by any method. Thanks!

I was simply referring to a custom dict that users often save to hold additional hyperparameters and training information. E.g., you will find code such as:

state = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epochs": nb_epochs,
    "any_additional_information": value
}

So I was suggesting that you store the value in such a dict, or in any other class you like, instead of in the optimizer's state_dict.
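A minimal sketch of that suggestion, assuming a toy model trained on random data: the hypothetical `TrainState` dataclass holds the previous step's gradient norm, which is used as `max_norm` for the next call to `torch.nn.utils.clip_grad_norm_` and is saved next to the model and optimizer state. Two assumptions to flag: the first step falls back to an assumed threshold of 1.0 (`default_max_norm`), and the stored value is the total norm that `clip_grad_norm_` returns (computed over all parameters together, before clipping), not the per-parameter maximum after adding noise that the original code computes.

```python
import torch
from dataclasses import dataclass, asdict
from typing import Optional


@dataclass
class TrainState:
    # Hypothetical container for extra training information
    prev_grad_norm: Optional[float] = None


torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 3))
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
train_state = TrainState()
default_max_norm = 1.0  # assumed threshold for the very first step

norm_history = []
for step in range(5):
    # Random toy batch standing in for a real data loader
    inputs = torch.randn(32, 8)
    labels = torch.randint(0, 3, (32,))

    optimizer.zero_grad()
    loss = criterion(net(inputs), labels)
    loss.backward()

    # Use the previous step's norm as this step's threshold (fallback on step 0)
    if train_state.prev_grad_norm is not None:
        max_norm = train_state.prev_grad_norm
    else:
        max_norm = default_max_norm
    # clip_grad_norm_ returns the total norm of the gradients before clipping
    total_norm = torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=max_norm)
    optimizer.step()

    # Remember this step's norm for the next iteration
    train_state.prev_grad_norm = total_norm.item()
    norm_history.append(train_state.prev_grad_norm)

# Bundle the extra state with the usual state_dicts
checkpoint = {
    "model": net.state_dict(),
    "optimizer": optimizer.state_dict(),
    "train_state": asdict(train_state),
}
print(checkpoint["train_state"])
```

The `checkpoint` dict can be written with `torch.save(checkpoint, path)` as usual, and when resuming, `TrainState(**checkpoint["train_state"])` restores the previous norm. A plain Python variable outside the loop would work just as well; the dataclass only makes it easy to group and save several such values.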