Slowdown over time using momentum

I have created my own optimizer by subclassing torch.optim.Optimizer; it implements a variant of the Frank-Wolfe (a.k.a. conditional gradient descent) method. When initializing the optimizer, I pass a momentum value 0 <= m <= 1. I recently noticed that the optimizer becomes slower and slower with each step, and I cannot figure out why. I did some timings and found that the slowdown comes from my momentum code. The following is the momentum part of my step() function:

def step(self, closure=None):
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            d_p = p.grad

            # Add momentum
            momentum = group['momentum']
            if momentum > 0:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    # First iteration: initialize the buffer with the gradient itself
                    param_state['momentum_buffer'] = d_p.detach().clone()
                else:
                    # Later iterations: exponential moving average of the gradients
                    param_state['momentum_buffer'].mul_(momentum).add_(d_p, alpha=1 - momentum)
                d_p = param_state['momentum_buffer']

In the first iteration of step(), the momentum_buffer is not set, so I initialize it with the last gradient itself. In later iterations I multiply the old momentum buffer by the momentum factor (e.g. 0.9) and add d_p, i.e. p.grad, to it.
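Stripped of the optimizer boilerplate, the update described above is just an in-place exponential moving average. A minimal sketch with plain tensors (the function name is illustrative, not part of the optimizer):

```python
import torch

def ema_update(buf: torch.Tensor, grad: torch.Tensor, momentum: float) -> torch.Tensor:
    # buf <- momentum * buf + (1 - momentum) * grad, computed in place
    return buf.mul_(momentum).add_(grad, alpha=1 - momentum)

buf = torch.tensor([1.0])
grad = torch.tensor([0.0])
ema_update(buf, grad, 0.9)  # buf is now 0.9 * 1.0 + 0.1 * 0.0 = 0.9
```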

What could be the problem here? Am I accumulating history? I am unsure why this gets slower and slower over time; in the beginning it is fairly fast, but the runtime is monotonically increasing.

That could be the case. Do you also see increased memory usage?
If so, could you check whether your momentum buffer has a valid .grad_fn after a couple of iterations?

If it does, it is storing the computation graph. However, since you are already using the no_grad context manager, this should not be the case.
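To check this, you can inspect the buffers stored in the optimizer's state after a few steps. A hypothetical helper (not part of your optimizer) could look like this; a non-None .grad_fn means the buffer is still attached to a growing computation graph:

```python
import torch

def check_buffers(optimizer):
    """Print whether each momentum buffer carries autograd history."""
    for group in optimizer.param_groups:
        for p in group['params']:
            buf = optimizer.state.get(p, {}).get('momentum_buffer')
            if buf is not None:
                # grad_fn should be None if step() runs without autograd tracking
                print(f"requires_grad={buf.requires_grad}, grad_fn={buf.grad_fn}")
```

For comparison, the built-in torch.optim.SGD keeps its momentum buffers detached, so grad_fn is always None there.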

How did you isolate that the momentum calculation is responsible for the slowdown?
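One common way to attribute the slowdown is to time the momentum block separately from the rest of step() over many iterations and watch whether its elapsed time grows. A minimal, generic timing helper (illustrative, not from the original code):

```python
import time

def timed(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) and return (result, elapsed seconds)."""
    # On GPU, call torch.cuda.synchronize() before and after fn to get
    # accurate wall-clock timings, since CUDA kernels launch asynchronously.
    t0 = time.perf_counter()
    out = fn(*args, **kwargs)
    return out, time.perf_counter() - t0
```

If the per-call time of the momentum block increases monotonically while everything else stays flat, that points at a growing structure (such as an attached computation graph) inside that block.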