Slowdown over time using momentum

I have created my own torch.optim.Optimizer subclass which implements a variant of the Frank-Wolfe (aka conditional gradient) method. When initializing the optimizer, I pass a momentum value 0 <= m <= 1. I recently noticed that the optimizer becomes slower and slower with each step and I cannot figure out why. I did some timings and found that the slowdown comes from my momentum code. The following is the momentum part of my step() function:

@torch.no_grad()
def step(self, closure=None):
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            d_p = p.grad

            # Add momentum
            momentum = group['momentum']
            if momentum > 0:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    # First step: initialize the buffer with the current gradient
                    param_state['momentum_buffer'] = d_p.detach().clone()
                else:
                    # Later steps: in-place exponential moving average of gradients
                    param_state['momentum_buffer'].mul_(momentum).add_(d_p, alpha=1 - momentum)
                    d_p = param_state['momentum_buffer']
...

In the first iteration of step(), the momentum_buffer is not set, so I initialize it to the gradient itself. In later iterations I multiply the old momentum buffer by the momentum factor (e.g. 0.9) and add d_p, i.e. p.grad, scaled by 1 - momentum.
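For clarity, this is the update the loop above performs, shown standalone with plain tensors (the gradient values here are just made up for illustration):

import torch

momentum = 0.9
grads = [torch.tensor([1.0, 2.0]), torch.tensor([0.5, 0.5]), torch.tensor([0.0, 1.0])]

# First step: buffer is initialized to the gradient itself
buf = grads[0].detach().clone()

# Later steps: in-place exponential moving average,
# buf = momentum * buf + (1 - momentum) * grad
for g in grads[1:]:
    buf.mul_(momentum).add_(g, alpha=1 - momentum)

print(buf)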

What could be the problem here? Am I accumulating history somewhere? I do not see why this would get slower over time; the first steps are fairly fast, but the runtime per step increases monotonically.

That could be the case. Do you also see increased memory usage?
If so, could you check whether your momentum buffer has a valid .grad_fn after a couple of iterations?

If it does, it could be storing the computation graph. However, since you are already using the no_grad() decorator, this should not be the case.
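For debugging, something along these lines inside step() (using the names from your posted snippet) should show both the memory usage and the .grad_fn:

# Debugging sketch, placed right after the momentum update in step()
buf = param_state['momentum_buffer']
print('grad_fn:', buf.grad_fn)               # should print None under no_grad
print('requires_grad:', buf.requires_grad)   # should print False
if torch.cuda.is_available():
    print('allocated MB:', torch.cuda.memory_allocated() / 1024**2)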

How did you isolate that the momentum calculation is responsible for the slowdown?
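In case it helps, a minimal timing pattern would be something like the following, where optimizer stands in for your Frank-Wolfe instance; the synchronize calls matter on the GPU because CUDA kernels run asynchronously, and without them the measured time can be misleading:

import time
import torch

# Illustrative timing of a single optimizer step
if torch.cuda.is_available():
    torch.cuda.synchronize()
t0 = time.perf_counter()
optimizer.step()
if torch.cuda.is_available():
    torch.cuda.synchronize()
print('step time [s]:', time.perf_counter() - t0)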