I understand that it is a very elementary question. Is something like the following correct?
torch.optim.SGD(params, lr=0.01, momentum=0.9)
I ask this because I try to replicate the pytorch lightning tutorial regarding optimizer here. Rather than implementing optimizers from scratch as in the tutorial, I used the function from torch.optim directly. In particular, in [32], I replaced
SGDMom_points = train_curve(lambda params: SGDMomentum(params, lr=10, momentum=0.9))
by
SGDMom_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9))
The result appears to be much worse than indicated in the tutorial. The result for Nesterov accelerated gradient as implemented below appears to be worse than one expects as well.
NAG_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9, nesterov=True))
I have a feeling that I could have mess up something stupid. But I can’t really spot it and hopefully someone can help.
Disclaimer: I also cross-posted on stackoverflow (I can’t add a link to there as I can’t add more than two hyperlinks) but didn’t get a reply yet. I later realize that this probably is a more relevant channel for the question.
I included my complete test code below. Please note that nothing has changed except the few lines to compute SGD_points, SGDMom_points, and Adam_points
from matplotlib import cm
import seaborn as sns
from matplotlib import pyplot as plt
import torch
import numpy as np
def pathological_curve_loss(w1, w2):
# Example of a pathological curvature. There are many more possible, feel free to experiment here!
x1_loss = torch.tanh(w1) ** 2 + 0.01 * torch.abs(w1)
x2_loss = torch.sigmoid(w2)
return x1_loss + x2_loss
def plot_curve(
curve_fn, x_range=(-5, 5), y_range=(-5, 5), plot_3d=False, cmap=cm.viridis, title="Pathological curvature"
):
fig = plt.figure()
ax = fig.add_subplot(projection='3d') if plot_3d else fig.gca()
# ax = fig.gca(projection="3d") if plot_3d else fig.gca()
x = torch.arange(x_range[0], x_range[1], (x_range[1] - x_range[0]) / 100.0)
y = torch.arange(y_range[0], y_range[1], (y_range[1] - y_range[0]) / 100.0)
x, y = torch.meshgrid([x, y])
z = curve_fn(x, y)
x, y, z = x.numpy(), y.numpy(), z.numpy()
if plot_3d:
ax.plot_surface(x, y, z, cmap=cmap, linewidth=1, color="#000", antialiased=False)
ax.set_zlabel("loss")
else:
ax.imshow(z.T[::-1], cmap=cmap, extent=(x_range[0], x_range[1], y_range[0], y_range[1]))
plt.title(title)
ax.set_xlabel(r"$w_1$")
ax.set_ylabel(r"$w_2$")
plt.tight_layout()
return ax
# sns.reset_orig()
# _ = plot_curve(pathological_curve_loss, plot_3d=True)
# plt.show()
from torch import nn
def train_curve(optimizer_func, curve_func=pathological_curve_loss, num_updates=100, init=[5, 5]):
"""
Args:
optimizer_func: Constructor of the optimizer to use. Should only take a parameter list
curve_func: Loss function (e.g. pathological curvature)
num_updates: Number of updates/steps to take when optimizing
init: Initial values of parameters. Must be a list/tuple with two elements representing w_1 and w_2
Returns:
Numpy array of shape [num_updates, 3] with [t,:2] being the parameter values at step t, and [t,2] the loss at t.
"""
weights = nn.Parameter(torch.FloatTensor(init), requires_grad=True)
optim = optimizer_func([weights])
list_points = []
for _ in range(num_updates):
loss = curve_func(weights[0], weights[1])
list_points.append(torch.cat([weights.data.detach(), loss.unsqueeze(dim=0).detach()], dim=0))
optim.zero_grad()
loss.backward()
optim.step()
points = torch.stack(list_points, dim=0).numpy()
return points
# BEGIN only place changed from https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_intermediate.html
SGD_points = train_curve(lambda params: torch.optim.SGD(params, lr=10))
SGDMom_points = train_curve(lambda params: torch.optim.SGD(params, lr=10, momentum=0.9))
Adam_points = train_curve(lambda params: torch.optim.Adam(params, lr=1))
# END only place changed from https://pytorch-lightning.readthedocs.io/en/stable/deploy/production_intermediate.html
all_points = np.concatenate([SGD_points, SGDMom_points, Adam_points], axis=0)
ax = plot_curve(
pathological_curve_loss,
x_range=(-np.absolute(all_points[:, 0]).max(), np.absolute(all_points[:, 0]).max()),
y_range=(all_points[:, 1].min(), all_points[:, 1].max()),
plot_3d=False,
)
ax.plot(SGD_points[:, 0], SGD_points[:, 1], color="red", marker="o", zorder=1, label="SGD")
ax.plot(SGDMom_points[:, 0], SGDMom_points[:, 1], color="blue", marker="o", zorder=2, label="SGDMom")
ax.plot(Adam_points[:, 0], Adam_points[:, 1], color="grey", marker="o", zorder=3, label="Adam")
plt.legend()
plt.show()