# How to implement SGD with momentum?

I understand that it is a very elementary question. Is something like the following correct?

``````torch.optim.SGD(params, lr=0.01, momentum=0.9)
``````

I ask because I am trying to replicate the PyTorch Lightning tutorial on optimizers. Rather than implementing the optimizers from scratch as the tutorial does, I used the classes from `torch.optim` directly. In particular, in the tutorial's training code I replaced

``````SGDMom_points = train_curve(lambda params: SGDMomentum(params, lr=10, momentum=0.9))
``````

by

``````SGDMom_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9))
``````

The result is much worse than what the tutorial shows. The result for Nesterov accelerated gradient, implemented as below, is also worse than I would expect.

``````NAG_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9, nesterov=True))
``````

I suspect I have messed up something simple, but I can't spot it; hopefully someone can help.

My complete test code is included below. Note that nothing has changed from the tutorial except the few lines that compute SGD_points, SGDMom_points, and Adam_points.

``````from matplotlib import cm
import seaborn as sns
from matplotlib import pyplot as plt
import torch
import numpy as np

def pathological_curve_loss(w1, w2):
    """Evaluate a toy loss surface with pathological curvature.

    The loss is the sum of a w_1 term, ``tanh(w1)**2 + 0.01 * |w1|``, and a
    w_2 term, ``sigmoid(w2)``. There are many other pathological shapes one
    could try here.

    Args:
        w1: Tensor of values for the first parameter.
        w2: Tensor of values for the second parameter (same shape as, or
            broadcastable with, ``w1``).

    Returns:
        Tensor holding the loss for each (w1, w2) pair.
    """
    w1_term = torch.tanh(w1).pow(2) + 0.01 * w1.abs()
    w2_term = torch.sigmoid(w2)
    return w1_term + w2_term

def plot_curve(
    curve_fn, x_range=(-5, 5), y_range=(-5, 5), plot_3d=False, cmap=cm.viridis, title="Pathological curvature"
):
    """Plot a two-parameter loss surface as a 2D heat map or a 3D surface.

    Args:
        curve_fn: Callable ``curve_fn(w1, w2)`` evaluated on grid tensors of w_1 and w_2.
        x_range: ``(low, high)`` interval of w_1 to plot.
        y_range: ``(low, high)`` interval of w_2 to plot.
        plot_3d: If True draw a 3D surface, otherwise a 2D heat map.
        cmap: Matplotlib colormap used for the surface / heat map.
        title: Figure title.

    Returns:
        The matplotlib axes that were drawn on, so callers can overlay optimizer trajectories.
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d") if plot_3d else fig.gca()

    # Roughly 100 samples per axis; arange excludes the upper bound.
    x = torch.arange(x_range[0], x_range[1], (x_range[1] - x_range[0]) / 100.0)
    y = torch.arange(y_range[0], y_range[1], (y_range[1] - y_range[0]) / 100.0)
    # "ij" is torch's historical default; passing it explicitly avoids the meshgrid warning.
    x, y = torch.meshgrid(x, y, indexing="ij")
    z = curve_fn(x, y)
    x, y, z = x.numpy(), y.numpy(), z.numpy()

    if plot_3d:
        ax.plot_surface(x, y, z, cmap=cmap, linewidth=1, color="#000", antialiased=False)
        ax.set_zlabel("loss")
    else:
        # With "ij" indexing z[i, j] belongs to (x[i], y[j]): transpose so rows follow w_2,
        # then flip so the largest w_2 ends up at the top of the image.
        ax.imshow(z.T[::-1], cmap=cmap, extent=(x_range[0], x_range[1], y_range[0], y_range[1]))
    plt.title(title)
    # Raw strings without backslashes so matplotlib renders these as math text.
    ax.set_xlabel(r"$w_1$")
    ax.set_ylabel(r"$w_2$")
    plt.tight_layout()
    return ax

# sns.reset_orig()
# _ = plot_curve(pathological_curve_loss, plot_3d=True)
# plt.show()

from torch import nn

def train_curve(optimizer_func, curve_func=pathological_curve_loss, num_updates=100, init=(5, 5)):
    """
    Optimize the two parameters (w_1, w_2) of ``curve_func`` and record the trajectory.

    Args:
        optimizer_func: Constructor of the optimizer to use. Should only take a parameter list
        curve_func: Loss function (e.g. pathological curvature)
        num_updates: Number of optimizer steps to take
        init: Initial values of parameters. Must be a list/tuple with two elements representing w_1 and w_2
    Returns:
        Numpy array of shape [num_updates, 3] with [t,:2] being the parameter values at step t, and [t,2] the loss at t.
    """
    # One 2-element parameter tensor holds (w_1, w_2).
    weights = nn.Parameter(torch.tensor(init, dtype=torch.float32), requires_grad=True)
    optim = optimizer_func([weights])

    list_points = []
    for _ in range(num_updates):
        loss = curve_func(weights[0], weights[1])
        # Record the parameters *before* this step's update together with their loss;
        # torch.cat copies, so later in-place updates do not alter the stored row.
        list_points.append(torch.cat([weights.detach(), loss.detach().unsqueeze(dim=0)], dim=0))
        # .grad accumulates across backward() calls, so clear it before every step.
        optim.zero_grad()
        loss.backward()
        optim.step()

    if not list_points:
        # torch.stack rejects an empty list; keep the documented [num_updates, 3] shape.
        return np.empty((0, 3), dtype=np.float32)
    return torch.stack(list_points, dim=0).numpy()

# BEGIN only place changed from https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_intermediate.html
SGD_points = train_curve(lambda params: torch.optim.SGD(params, lr=10))
# The tutorial's SGDMomentum keeps an exponential moving average of the gradient:
#     m_t = 0.9 * m_{t-1} + 0.1 * g_t   (m_0 = 0),   p -= 10 * m_t
# torch.optim.SGD (default dampening=0) instead keeps a plain running sum:
#     b_t = 0.9 * b_{t-1} + g_t         (b_1 = g_1), p -= lr * b_t
# By induction b_t == 10 * m_t, so the equivalent torch learning rate is
# 10 * (1 - 0.9) = 1. Reusing lr=10 makes every step ~10x larger than in the tutorial.
# The same rescaling (lr=1 instead of 10) gives comparable step sizes with nesterov=True.
SGDMom_points = train_curve(lambda params: torch.optim.SGD(params, lr=1, momentum=0.9))
# NOTE(review): lr=1 assumed to match the tutorial's Adam(params, lr=1) — confirm.
Adam_points = train_curve(lambda params: torch.optim.Adam(params, lr=1))
# END only place changed from https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_intermediate.html

all_points = np.concatenate([SGD_points, SGDMom_points, Adam_points], axis=0)
w1_extent = np.absolute(all_points[:, 0]).max()
ax = plot_curve(
    pathological_curve_loss,
    x_range=(-w1_extent, w1_extent),
    y_range=(all_points[:, 1].min(), all_points[:, 1].max()),
    plot_3d=False,
)
ax.plot(SGD_points[:, 0], SGD_points[:, 1], color="red", marker="o", zorder=1, label="SGD")
ax.plot(SGDMom_points[:, 0], SGDMom_points[:, 1], color="blue", marker="o", zorder=2, label="SGDMom")