I am calculating the norm of the gradients at the current step during optimisation and would like to access it at the next step. I have tried to store and retrieve it through the optimizer's `state_dict()`. Below is my code:
```python
for e in range(epochs):
    correct, total, epoch_loss = 0, 0, 0.0
    for batch_idx, (inputs, labels) in enumerate(loader):
        # On the first step, initialise prev_grad_norm and try to stash it in the optimizer state
        if 'prev_grad_norm' not in optimizer.state_dict():
            prev_grad_norm = torch.zeros(1)
            optimizer_state = optimizer.state_dict()
            optimizer_state['prev_grad_norm'] = prev_grad_norm
            optimizer.load_state_dict(optimizer_state)
            print("1. Norm in previous step:", prev_grad_norm, "\n")
        else:
            prev_grad_norm = optimizer.state_dict()['prev_grad_norm']
            print("Norm in previous step:", prev_grad_norm, "\n")

        optimizer.zero_grad()
        data = inputs.view(inputs.shape[0], -1)
        outputs = net(data)
        loss = criterion(outputs, labels)
        loss.backward()

        # Clip the gradients, then add noise to them
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
        for param in net.parameters():
            if param.grad is not None:
                param.grad += torch.randn_like(param.grad) * noise_scale
        optimizer.step()

        # Largest per-parameter L2 gradient norm of the current step
        current_norm = torch.max(torch.stack(
            [torch.norm(p.grad, 2) for p in net.parameters() if p.grad is not None]))
        print("2. Norm in current step:", current_norm, "\n")

        # Metrics
        epoch_loss += loss.item()
        total += labels.size(0)
        correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()

        print("3. Norm updated in current step:", current_norm, "\n")
        # Attempt to store the norm for the next step
        optimizer.state_dict()['prev_grad_norm'] = current_norm
        print(optimizer.state_dict(), "\n")
```
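For reference, here is a minimal reproduction of the behaviour outside the training loop (the toy `Linear` model and SGD optimizer are just stand-ins for my actual setup). As far as I can tell, both paths lose the custom key:

```python
import torch

net = torch.nn.Linear(4, 2)                          # toy model standing in for my net
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

# Mutating the dict returned by state_dict() does not persist:
sd = optimizer.state_dict()
sd['prev_grad_norm'] = torch.tensor(0.5)
print('prev_grad_norm' in optimizer.state_dict())    # False

# Round-tripping through load_state_dict() does not keep it either:
sd2 = optimizer.state_dict()
sd2['prev_grad_norm'] = torch.tensor(0.5)
optimizer.load_state_dict(sd2)
print('prev_grad_norm' in optimizer.state_dict())    # still False
```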
I can see that the norm is being calculated correctly at the current step, but when I try to read it back through `prev_grad_norm` at the next step, it shows up as None. My reason for accessing the previous step's norm is to use it as the clipping threshold at the next step.
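For context, this is the behaviour I am ultimately after, sketched with a plain Python variable instead of the optimizer state (the initial threshold of 1.0 is an arbitrary placeholder, and the noise addition is omitted for brevity):

```python
prev_grad_norm = 1.0  # fallback clipping threshold for the very first step
for batch_idx, (inputs, labels) in enumerate(loader):
    optimizer.zero_grad()
    outputs = net(inputs.view(inputs.shape[0], -1))
    loss = criterion(outputs, labels)
    loss.backward()

    # Clip using the norm observed at the previous step
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=prev_grad_norm)
    optimizer.step()

    # Remember this step's largest per-parameter gradient norm for the next step
    prev_grad_norm = max(p.grad.norm(2).item()
                         for p in net.parameters() if p.grad is not None)
```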
Why is the state not being passed along here?