I am calculating the norm of the gradients in current step during optimisation and would like to access it in the next step.I have tried to access it through optimizer state_dict(). Below is my code for the same.
for e in range(epochs):
correct, total, epoch_loss = 0, 0, 0.0
for batch_idx, (inputs, target) in enumerate(loader):
# Initialize grad_prev as a tensor of zeros
if 'prev_grad_norm' not in optimizer.state_dict():
prev_grad_norm = torch.Tensor(100)
optimizer_state = optimizer.state_dict()
optimizer_state['prev_grad_norm'] = prev_grad_norm
optimizer.load_state_dict(optimizer_state)
print("1. Norm in Previous Epoch:", grad_prev)
print("\n")
else:
prev_grad_norm = optimizer.state_dict()['prev_grad_norm']
print("Norm in Previous Epoch:", prev_grad_norm)
optimizer.zero_grad()
data = inputs.view(inputs.shape[0], -1)
outputs = net(data)
loss = criterion(outputs, labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
for param in net.parameters():
if param.grad is not None:
param.grad += torch.randn_like(param.grad) * noise_scale
optimizer.step()
current_norm = torch.max(torch.tensor([torch.norm(p.grad, 2) for p in net.parameters()]))
print("2. Norm in current Epoch:", current_norm)
print("\n")
# Metrics
epoch_loss += loss
total += labels.size(0)
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
print("3. Norm updated in current Epoch:", current_norm)
print("\n")
optimizer.state_dict()['prev_grad_norm'] = current_norm
print(optimizer.state_dict())
print("\n")
optimizer.step()
Though I see that norm is being calculated at the current step, but when I try to access it through ‘prev_grad_norm’ in the next step, it is showing as None. My purpose of accessing it from previous step is to use that as threshold for clipping in next step.
Why the state is not being passed here?