Get the real 'd_p' value that updates the params at each step

I want to inspect how the parameters of the network change at each step. The value I want is not exactly the gradient, which can be grabbed using 'param.grad', because if SGD with momentum is applied, the actual update will be

lr * (momentum term 90% + gradient only 10%)

and I want to capture this aggregated value.
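
To make that concrete, with momentum = 0.9 and the default dampening = 0, the update is roughly the following (a toy sketch with made-up numbers, not the actual library code):

import torch

lr, momentum = 0.1, 0.9        # illustrative hyper-parameters

param = torch.randn(3)         # a toy parameter
grad = torch.randn(3)          # its gradient for this step
buf = torch.zeros_like(param)  # the optimizer's momentum buffer

buf = momentum * buf + grad    # this aggregated term is the 'd_p' I am after
param = param - lr * buf       # the value actually written into the parameter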

I know that it is possible to get this by calculating the difference between the params of the current step and the previous step. But in fact, this value is already calculated inside the optimizer's step() function, optimizer.step(), as shown in the figure below:

I want to get the 'd_p' value from the second-to-last line, p.data.add_(-group['lr'], d_p), because it is the real value, computed by the optimizer, that actually gets applied to the network parameters. Is there any way to get it directly, or do I need to override the step function myself?
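
For reference, the difference-based workaround I mentioned looks roughly like this (the model and optimizer below are dummies for illustration):

import torch
from torch import nn, optim

# dummy model and optimizer just to make the sketch runnable
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

loss = model(torch.randn(8, 4)).sum()
loss.backward()

before = [p.detach().clone() for p in model.parameters()]
optimizer.step()
deltas = [p.detach() - b for p, b in zip(model.parameters(), before)]
# each delta is the full update applied by the optimizer, i.e. -lr * d_p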

Hi,

Given that this is very specific to SGD, this is not part of the general optimizer interface.
You will have to modify the .step() function to return what you need.

Hi,

I inherited from the 'optim.SGD' class and overrode the 'step' function. This is what I have done:

import torch
from torch import optim


class SGDExtended(optim.SGD):

    def step(self, closure=None):
        """Performs a single optimization step and also returns the d_p
        tensors that were applied to the parameters.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        d_ps_groups = []
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            d_ps = []
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    # weight decay folded into the gradient, as in stock SGD
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                # same update as stock SGD: p <- p - lr * d_p
                p.data.add_(d_p, alpha=-group['lr'])
                # clone before converting so later in-place updates to the
                # momentum buffer do not mutate the recorded value
                d_ps.append(d_p.detach().cpu().clone().numpy())
            d_ps_groups.append(d_ps)

        return loss, d_ps_groups
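
Here is roughly how I call it (the model and data below are made-up dummies, just to show how the returned d_ps_groups is consumed):

import torch
from torch import nn

model = nn.Linear(10, 1)
optimizer = SGDExtended(model.parameters(), lr=0.1, momentum=0.9)

x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
_, d_ps_groups = optimizer.step()

# d_ps_groups[0][i] is the d_p (as a numpy array) applied to the i-th parameter
# of the first param group; the change applied to that parameter was -lr * d_p
print([d.shape for d in d_ps_groups[0]])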

Thank you very much.