Custom optimizer replacing param.add_ in SGD

Hello,

I’m trying to implement a custom SGD optimizer by changing the way in which d_p * alpha gets added to param. I started from the following custom implementation, which copies the SGD code from https://pytorch.org/docs/stable/_modules/torch/optim/sgd.html#SGD with a small change in _single_tensor_sgd():

import torch


def _single_tensor_sgd(
    params,
    d_p_list,
    momentum_buffer_list,
    *,
    weight_decay,
    momentum,
    lr,
    dampening,
    nesterov,
    maximize,
    has_sparse_grad
):
    for i, param in enumerate(params):

        d_p = d_p_list[i]
        # fold weight decay into the gradient (L2 regularization)
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        # classical (or Nesterov) momentum via a per-parameter buffer
        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        alpha = lr if maximize else -lr

        # Changed this part
        #param.add_(d_p, alpha=alpha)  # <-- matches PyTorch
        param = param + d_p * alpha    # <-- doesn't match PyTorch

To check whether this change makes any difference, I tested it against PyTorch's implementation as shown below:

import torch
# MySGD is the same as torch.optim.SGD except for the _single_tensor_sgd function
from test_SGD import MySGD
import copy

class SimpleLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)

in_features = 10
out_features = 50

# Create two equivalent models
torch_model = SimpleLinear(in_features, out_features)
my_model = copy.deepcopy(torch_model)
my_model.load_state_dict(torch_model.state_dict())

torch_model.train()
my_model.train()

lr = 0.001

# My optimizer
my_optim = MySGD(my_model.parameters(), lr)
# Pytorch optimizer
torch_optim = torch.optim.SGD(torch_model.parameters(), lr)

# Mock inputs
inputs = torch.randn((2, in_features))

# Performs "training" step
def get_my_opt_state():
    torch.manual_seed(10)
    y = my_model(inputs)
    y.sum().backward()
    my_optim.step()
    my_model.zero_grad()

def get_torch_opt_state():
    torch.manual_seed(10)
    y = torch_model(inputs)
    y.sum().backward()
    torch_optim.step()
    torch_model.zero_grad()

# Compute error between torch and my weights
def compute_abs_error(x, y):
    return torch.sum(torch.abs(x - y))

# Do one mock epoch
for epoch in range(1):
    get_torch_opt_state()
    get_my_opt_state()

# Compare resultant weights
torch_w = torch_model.linear.weight
my_w = my_model.linear.weight

error = compute_abs_error(torch_w, my_w)
print("Error: ", error)

However, with param.add_(d_p, alpha=alpha) the error is zero, while with param = param + d_p * alpha I get a non-zero error. What could be causing this difference, or is my testing incorrect? I would like to write a custom update step, so I want to make sure I don't run into the same issue there.

Hi lnair!

param = param + d_p * alpha creates a new tensor and rebinds the Python
name param to that new tensor, rather than modifying the tensor you wish
to train. Conversely, param.add_(d_p, alpha=alpha) modifies the desired
tensor in place.

One approach that would keep much of your code unchanged would be:

with torch.no_grad():
    param.copy_(param + d_p * alpha)
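Concretely, here is a small self-contained sketch of what that update does, using made-up stand-in values for param, d_p, and lr rather than your actual optimizer state:

import torch

param = torch.nn.Parameter(torch.ones(3))  # stand-in for one model parameter
d_p = torch.full((3,), 0.5)                # stand-in for its gradient
lr, maximize = 0.1, False

alpha = lr if maximize else -lr

# copy_() writes into param's existing storage, so the tensor registered
# with the model is the one that changes; `param = param + d_p * alpha`
# would only rebind the local name to a new tensor
with torch.no_grad():
    param.copy_(param + d_p * alpha)

print(param)  # the values are now 0.95 everywhere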

To illustrate the issue of setting param to refer to a new tensor, consider:

>>> import torch
>>> torch.__version__
'1.10.2'
>>> param = torch.ones(2, 3)
>>> orig_param = param
>>> _ = param.add_(torch.ones(2, 3))
>>> param
tensor([[2., 2., 2.],
        [2., 2., 2.]])
>>> orig_param
tensor([[2., 2., 2.],
        [2., 2., 2.]])
>>> param = param + torch.ones(2, 3)
>>> param
tensor([[3., 3., 3.],
        [3., 3., 3.]])
>>> orig_param
tensor([[2., 2., 2.],
        [2., 2., 2.]])

Best.

K. Frank

Thank you for the response! That clears it up for me. There is still a very small difference in results (on the order of 1e-9), but I can work around that.
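As a side note, a residual difference that small is most likely just floating-point rounding: param.add_(d_p, alpha=alpha) and param + d_p * alpha can round intermediates differently (for example, a fused kernel may use a fused multiply-add while the explicit expression materializes d_p * alpha first), so the results can disagree by roughly one ulp per element. A quick, non-authoritative way to see this kind of effect (the printed value depends on hardware and backend and may also come out exactly zero):

import torch

torch.manual_seed(0)
param = torch.randn(1000)
d_p = torch.randn(1000)
alpha = -0.001

fused = param.add(d_p, alpha=alpha)  # same fused-alpha op the optimizer uses (out-of-place here)
separate = param + d_p * alpha       # explicit multiply, then add

# prints 0.0 or something tiny (roughly 1e-10 .. 1e-7), depending on the backend
print(torch.max(torch.abs(fused - separate)))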