How to implement SGD with momentum?

I realize this is a very elementary question, but is the following correct?

torch.optim.SGD(params, lr=0.01, momentum=0.9)

I ask because I am trying to replicate the PyTorch Lightning tutorial on optimizers (linked here). Rather than implementing the optimizers from scratch as the tutorial does, I used the classes from torch.optim directly. In particular, in cell [32], I replaced

SGDMom_points = train_curve(lambda params: SGDMomentum(params, lr=10, momentum=0.9))

with

SGDMom_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9))

The result is much worse than the one shown in the tutorial. The result for Nesterov accelerated gradient, implemented as below, is also worse than I would expect.

NAG_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9, nesterov=True))

I have a feeling I have messed up something stupid, but I can't spot it, and I hope someone can help.

Disclaimer: I also cross-posted this on Stack Overflow (I can't link to it because I can't add more than two hyperlinks) but haven't received a reply yet. I later realized that this is probably a more relevant channel for the question.

My complete test code is included below. Note that nothing has changed from the tutorial except the few lines that compute SGD_points, SGDMom_points, and Adam_points.

from matplotlib import cm
import seaborn as sns
from matplotlib import pyplot as plt
import torch
import numpy as np

def pathological_curve_loss(w1, w2):
    """Return a toy two-parameter loss with pathological curvature.

    The w1 term, tanh(w1)**2 + 0.01*|w1|, forms a narrow valley around w1 = 0
    and flattens out for large |w1| because tanh saturates. The w2 term,
    sigmoid(w2), keeps decreasing as w2 goes to -inf. All operations are
    element-wise, so w1 and w2 may be scalars or same-shaped tensors.
    """
    # Many other pathological shapes are possible; feel free to experiment here.
    valley = torch.tanh(w1).pow(2) + torch.abs(w1) * 0.01
    slope = torch.sigmoid(w2)
    return valley + slope

def plot_curve(
    curve_fn, x_range=(-5, 5), y_range=(-5, 5), plot_3d=False, cmap=cm.viridis, title="Pathological curvature"
):
    """Plot a two-parameter loss function over a rectangular grid.

    Args:
        curve_fn: Callable ``curve_fn(w1, w2)`` evaluated element-wise on grid
            tensors; must return a tensor that can be converted with ``.numpy()``.
        x_range: ``(low, high)`` range of w1 values, sampled with step
            ``(high - low) / 100`` (high excluded).
        y_range: ``(low, high)`` range of w2 values, sampled the same way.
        plot_3d: If True, draw a 3D surface; otherwise draw a 2D heat map.
        cmap: Matplotlib colormap used for the surface / heat map.
        title: Title of the plot.

    Returns:
        The matplotlib Axes that was drawn on (a 3D Axes when ``plot_3d``).
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d") if plot_3d else fig.gca()

    x = torch.arange(x_range[0], x_range[1], (x_range[1] - x_range[0]) / 100.0)
    y = torch.arange(y_range[0], y_range[1], (y_range[1] - y_range[0]) / 100.0)
    # "ij" indexing (torch's historical default, now requested explicitly to
    # avoid the deprecation warning): grid axis 0 follows w1, axis 1 follows w2.
    x, y = torch.meshgrid(x, y, indexing="ij")
    z = curve_fn(x, y)
    x, y, z = x.numpy(), y.numpy(), z.numpy()

    if plot_3d:
        ax.plot_surface(x, y, z, cmap=cmap, linewidth=1, color="#000", antialiased=False)
        ax.set_zlabel("loss")
    else:
        # z is indexed [w1, w2]; imshow wants [row, col] with row 0 drawn at the
        # top, so transpose and flip vertically to make w2 increase upward.
        ax.imshow(z.T[::-1], cmap=cmap, extent=(x_range[0], x_range[1], y_range[0], y_range[1]))
    ax.set_title(title)
    ax.set_xlabel(r"$w_1$")
    ax.set_ylabel(r"$w_2$")
    return ax

# sns.reset_orig()
# _ = plot_curve(pathological_curve_loss, plot_3d=True)

from torch import nn

def train_curve(optimizer_func, curve_func=pathological_curve_loss, num_updates=100, init=(5, 5)):
    """Run an optimizer on a two-parameter loss and record its trajectory.

    Args:
        optimizer_func: Constructor of the optimizer to use. Should only take a parameter list.
        curve_func: Loss function ``curve_func(w1, w2)`` (e.g. pathological curvature).
        num_updates: Number of updates/steps to take when optimizing.
        init: Initial values of the parameters. Must be a list/tuple with two
            elements representing w_1 and w_2.

    Returns:
        Numpy array of shape [num_updates, 3] with [t, :2] being the parameter
        values at step t (before the t-th update) and [t, 2] the loss at those
        values. For ``num_updates == 0`` an empty array of shape [0, 3].
    """
    weights = nn.Parameter(torch.tensor(init, dtype=torch.float32))
    optimizer = optimizer_func([weights])

    list_points = []
    for _ in range(num_updates):
        loss = curve_func(weights[0], weights[1])
        # Record the point before stepping. torch.cat copies the data, so the
        # in-place updates optimizer.step() makes to `weights` don't alter it.
        list_points.append(torch.cat([weights.detach(), loss.detach().unsqueeze(dim=0)], dim=0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if not list_points:
        # torch.stack rejects an empty list; keep the documented [N, 3] shape.
        return torch.empty((0, 3)).numpy()
    points = torch.stack(list_points, dim=0).numpy()
    return points

# BEGIN only place changed from the tutorial
# Note on scaling: torch.optim.SGD (dampening=0) accumulates momentum as
#     buf = momentum * buf + grad          (buf initialised to the first grad),
# whereas the tutorial's SGDMomentum uses an exponential moving average
#     m = momentum * m + (1 - momentum) * grad   (m initialised to 0).
# Hence m == (1 - momentum) * buf at every step, so torch's lr must be scaled
# by (1 - momentum) to match the tutorial: 10 * (1 - 0.9) = 1. Keeping lr=10
# makes every momentum step ~10x too large, which is why the curve looked worse.
# (dampening=0.9 is NOT equivalent: torch skips dampening on the first step.)
# The same scaling applies to nesterov=True, e.g. SGD(params, lr=1, momentum=0.9, nesterov=True).
SGD_points = train_curve(lambda params: torch.optim.SGD(params, lr=10))
SGDMom_points = train_curve(lambda params: torch.optim.SGD(params, lr=1, momentum=0.9))
Adam_points = train_curve(lambda params: torch.optim.Adam(params, lr=1))
# END only place changed from the tutorial

# Overlay the three optimizer trajectories on a 2D heat map of the loss.
all_points = np.concatenate([SGD_points, SGDMom_points, Adam_points], axis=0)
ax = plot_curve(
    pathological_curve_loss,
    # Symmetric w1 range so the valley at w1 = 0 sits in the middle of the plot.
    x_range=(-np.absolute(all_points[:, 0]).max(), np.absolute(all_points[:, 0]).max()),
    y_range=(all_points[:, 1].min(), all_points[:, 1].max()),
    plot_3d=False,
)
ax.plot(SGD_points[:, 0], SGD_points[:, 1], color="red", marker="o", zorder=1, label="SGD")
ax.plot(SGDMom_points[:, 0], SGDMom_points[:, 1], color="blue", marker="o", zorder=2, label="SGDMom")
ax.plot(Adam_points[:, 0], Adam_points[:, 1], color="grey", marker="o", zorder=3, label="Adam")
ax.legend()