Custom C++ extension for an optimizer: state_dict not updated

Hello all,

I am trying to build a custom C++ extension with PyTorch.
In particular, I am trying to define a new optimizer step() function, so I started by replicating the functionality of the SGD optimizer's step() as a C++ extension.

Here is the C++ code:

#include <torch/extension.h>
#include <iostream>
#include <optional>
#include <vector>

void my_sgd_function(std::vector<torch::Tensor> &params,
                     const std::vector<torch::Tensor> &d_p_list,
                     std::vector<std::optional<torch::Tensor>> &momentum_buffer_list,
                     double weight_decay,
                     double momentum,
                     double lr,
                     double dampening,
                     bool nesterov) {
    // Parameters are updated in place, so autograd must not record these ops
    torch::NoGradGuard no_grad;
    const size_t num_params = params.size();

    for (size_t i = 0; i < num_params; ++i) {
        torch::Tensor d_p = d_p_list[i];
        if (weight_decay != 0) {
            // Out-of-place add, so the parameter's .grad is left untouched
            d_p = d_p.add(params[i], weight_decay);
        }

        if (momentum != 0) {
            torch::Tensor buf;

            if (!momentum_buffer_list[i].has_value()) {
                // First step: the buffer starts as a copy of the gradient
                buf = d_p.clone().detach();
                momentum_buffer_list[i] = buf;
            } else {
                buf = momentum_buffer_list[i].value();
                buf.mul_(momentum).add_(d_p, 1 - dampening);
            }

            if (nesterov) {
                d_p = d_p.add(buf, momentum);
            } else {
                d_p = buf;
            }
        }

        params[i].add_(d_p, -lr);
    }
}

// Binding
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("my_sgd", &my_sgd_function, "my sgd computation (CPU)");
}
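For reference, this is the per-parameter rule that _single_tensor_sgd() applies and that the function above is meant to reproduce. The symbols are chosen just for this note: γ = lr, μ = momentum, τ = dampening, λ = weight_decay, g_t is the gradient, b_t the momentum buffer and θ_t the parameter (maximize is left out, as in the C++ code). The two middle lines apply only when μ ≠ 0.

```latex
\begin{aligned}
g_t &\leftarrow g_t + \lambda\,\theta_{t-1} \\
b_t &= \begin{cases}
  g_t & \text{if no buffer exists yet (first step)} \\
  \mu\, b_{t-1} + (1-\tau)\, g_t & \text{otherwise}
\end{cases} \\
g_t &\leftarrow \begin{cases}
  g_t + \mu\, b_t & \text{if nesterov} \\
  b_t & \text{otherwise}
\end{cases} \\
\theta_t &= \theta_{t-1} - \gamma\, g_t
\end{aligned}
```

The b_t line is why the buffer has to survive between calls: step t reads b_{t-1}, so if the buffer created on the first step never makes it back into the optimizer state, every step takes the first branch and (without Nesterov) the update reduces to plain SGD.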

My C++ function tries to mimic the behavior of _single_tensor_sgd().
After compiling and binding it, here is the Python code from which I want to call the custom C++ function (extracted from pytorch/ at master in the pytorch/pytorch GitHub repository):

import torch
import torch.optim._functional as F
from torch import Tensor, optim
from typing import List, Optional
# My module!
import sgd_cpu

class mySGD(optim.Optimizer):
    def __init__(self, params, lr, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, maximize: bool = False):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        maximize=maximize)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")

        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('maximize', False)

    def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
        has_sparse_grad = False

        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                d_p_list.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True

                state = self.state[p]
                if 'momentum_buffer' not in state:
                    momentum_buffer_list.append(None)
                else:
                    momentum_buffer_list.append(state['momentum_buffer'])

        return has_sparse_grad

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []

            has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list)

            # Calling my custom C++ function
            sgd_cpu.my_sgd(params_with_grad, d_p_list, momentum_buffer_list,
                           group['weight_decay'], group['momentum'], group['lr'],
                           group['dampening'], group['nesterov'])

            # update momentum_buffers in state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]
                state['momentum_buffer'] = momentum_buffer

        return loss

The code runs without errors. The problem is that setting the momentum parameter in mySGD has no effect at all. As far as I can tell, the problem is the update of momentum_buffer_list. At the beginning of training, this list contains only None entries, and it should be filled with copies of the gradients from d_p_list. This is done in the C++ code by

momentum_buffer_list[i] = buf;

However, back in Python, the list still contains the previous Nones, so state['momentum_buffer'] is never updated. (I checked this by printing the variables both inside the C++ code, right after the assignment, and again in Python, after my_sgd_function returns.)
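A minimal pure-Python model of what seems to be happening, assuming the extension receives a converted copy of the list rather than the list itself (pybind11, which torch/extension.h uses for the bindings, converts a Python list into a new std::vector by copying its elements). The names Buf, fake_cpp_call and set_missing_buffers are hypothetical, and Buf only stands in for a tensor:

```python
class Buf:
    """Hypothetical stand-in for a tensor: a mutable object shared by reference."""
    def __init__(self, value):
        self.value = value


def fake_cpp_call(func, py_list):
    # Hypothetical model of the binding: the list is converted into a new
    # container holding the same elements (a shallow copy)
    vec = list(py_list)
    func(vec)       # the "C++ side" works on vec, not on py_list
    return vec      # only a return value carries new entries back


def set_missing_buffers(vec):
    for i, item in enumerate(vec):
        if item is None:
            vec[i] = Buf(1.0)   # element reassignment: only the copy changes
        else:
            item.value *= 2     # in-place mutation: the caller sees it


existing = Buf(3.0)
momentum_buffer_list = [None, existing]
returned = fake_cpp_call(set_missing_buffers, momentum_buffer_list)

print(momentum_buffer_list[0])  # None: the new buffer never reached the caller
print(existing.value)           # 6.0: the in-place update did
print(returned[0].value)        # 1.0: returning the container recovers it
```

Element reassignment (momentum_buffer_list[i] = buf) only touches the copy, while in-place operations on objects that already existed (buf.mul_(...), params[i].add_(...)) remain visible, which matches the symptom described above.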

How could I solve this?
This is a naive example, but my goal is to define a new optimizer as a C++ extension that updates internal state (such as momentum buffers), as SGD or Adam do.

I noticed that pytorch/sgd.cpp (master branch of pytorch/pytorch) uses the optimizer state for the momentum buffer update (state->momentum_buffer(buf);), but I don't know how to apply this in my case, or whether I should use it in a C++ extension at all.

Any suggestions would be appreciated.

@ptrblck any insight on how to tackle this?

I don't know why your code is currently failing and why the momentum_buffer_list is not being updated. It might be easier to write your experiments in Python. Or did you see significant performance differences when using the C++ implementation?
Note that both would go through the dispatching mechanism and would use the same kernel calls. To reduce the number of kernel calls (if this is adding to the overhead), you could check the fused implementations of a few optimizers (such as Adam) and use e.g. torch._foreach_sub_.

I was able to solve this problem by returning momentum_buffer_list from the C++ function and reassigning it in Python, as follows:

In the C++ code:

std::vector<std::optional<torch::Tensor>> sgd_function(...) {
    // ... same body as before ...
    return momentum_buffer_list;
}

and then in the Python step function:

momentum_buffer_list = sgd_cpu.sgd(...)

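This surprised me, since Python lists are mutable, so I expected any function receiving the list to be able to modify it. I could not find much information about this, but it seems the list is not passed through by reference: when the extension is called, the Python list is converted into a new std::vector<std::optional<torch::Tensor>>, so momentum_buffer_list[i] = buf only changes that temporary vector, even though the parameter is declared as a reference. The tensors inside the vector still share their storage with the Python tensors, which is why in-place updates such as params[i].add_ do show up in Python.
So the problem was how the list argument is converted between Python and C++, rather than an incorrect use of libtorch.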
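A pure-Python scalar sketch of the same update, written in the return-and-reassign style of the fix and using floats instead of tensors. sgd_reference is a hypothetical helper, not part of PyTorch or of the extension; it may be handy for checking the extension's results on tiny inputs:

```python
from typing import List, Optional, Tuple


# Hypothetical scalar reference of the SGD update (floats, not tensors).
# Like the fixed extension, it returns the momentum buffers instead of
# relying on the caller's list being modified.
def sgd_reference(params: List[float], grads: List[float],
                  momentum_buffers: List[Optional[float]],
                  weight_decay: float, momentum: float, lr: float,
                  dampening: float, nesterov: bool
                  ) -> Tuple[List[float], List[Optional[float]]]:
    new_params: List[float] = []
    new_buffers: List[Optional[float]] = []
    for p, g, buf in zip(params, grads, momentum_buffers):
        d_p = g + weight_decay * p
        if momentum != 0:
            if buf is None:
                buf = d_p  # first step: the buffer starts as the gradient
            else:
                buf = momentum * buf + (1 - dampening) * d_p
            d_p = d_p + momentum * buf if nesterov else buf
        new_params.append(p - lr * d_p)
        new_buffers.append(buf)
    return new_params, new_buffers


# Two steps on one parameter with a constant gradient of 1.0.
params, buffers = [0.0], [None]
for _ in range(2):
    # Reassign both lists, as in the fix: the returned buffers feed the next step
    params, buffers = sgd_reference(params, [1.0], buffers,
                                    weight_decay=0.0, momentum=0.9, lr=0.1,
                                    dampening=0.0, nesterov=False)
print(params, buffers)  # approximately [-0.29] [1.9]

# Discarding the returned buffers: every call sees None, as in the bug report.
p_no_state = [0.0]
for _ in range(2):
    p_no_state, _unused = sgd_reference(p_no_state, [1.0], [None],
                                        0.0, 0.9, 0.1, 0.0, False)
print(p_no_state)  # approximately [-0.2]
```

In the second loop the parameter moves by lr * grad on every step regardless of momentum, which is the "momentum has no effect" behavior from the original question.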

@ptrblck the main reason for using a C++ extension is not the overhead, but that I want to modify the optimizer using certain operations defined in a third-party C++ library, so porting it to Python is not an option.