Use the .parameters() method in Python on a C++ nn::Module exposed through torch::jit::CustomClassHolder

Hello, I want to optimize in Python the weights of a Torch module implemented in C++. Here is the idea: I generate a dataset yrand = Wsol * xrand, implement in C++ a module that computes the W * x product, and then try to find the W that minimizes || W * xrand - yrand ||.

%%%%%%%%%%%
%myTorchClass.cpp%
%%%%%%%%%%%

#include <torch/script.h>
#include <torch/torch.h>

struct MyTorchClass : torch::jit::CustomClassHolder
{
    torch::Tensor g(torch::Tensor x) const
    {
        return m_W.mm(x);
    }

    void init()
    {
        m_W = torch::zeros({ 3, 3 }, torch::TensorOptions().dtype(torch::kFloat64).requires_grad(true));
        m_module.register_parameter("W", m_W, /*requires_grad=*/true);
    }

    // Expose a method to get the parameter, callable from Python.
    // Return by value: named_parameters() builds a temporary OrderedDict,
    // so returning a reference into it would dangle.
    torch::Tensor parameters() const {
        return m_module.named_parameters()["W"];
    }

    torch::Tensor m_W;
    torch::nn::Module m_module;
};

TORCH_LIBRARY(mynamespace, m)
{
    m.class_<MyTorchClass>("myclassname")
        .def(torch::init<>())
        .def("init", &MyTorchClass::init)
        .def("g", &MyTorchClass::g)
        .def("parameters", MyTorchClass::parameters);
}
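
For reference, this is roughly how the class gets loaded and instantiated on the Python side; the library path is just a placeholder, and my_torch_ops below is the handle used in main.py:

import torch

# Load the shared library containing the TORCH_LIBRARY registration above
# (placeholder path; use the actual build output).
torch.ops.load_library("build/libmytorchclass.so")

# Handle to the registered custom class; instances are created with my_torch_ops().
my_torch_ops = torch.classes.mynamespace.myclassname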

%%%%%%%%%%%
%main.py%
%%%%%%%%%%%

import torch

# my_torch_ops is the custom class handle loaded beforehand (see above).

xrand = torch.rand((3, 1000), dtype=torch.float64)
Wsol = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=torch.float64)
yrand = torch.matmul(Wsol, xrand)

class myObjective(torch.nn.Module):
    def __init__(self, inputs: torch.Tensor, targets: torch.Tensor):
        super().__init__()
        self.my_ops = my_torch_ops()  # instance of the custom C++ class
        self.my_ops.init()            # allocate W and register it as a parameter
        self.inputs = inputs
        self.targets = targets
    def forward(self):
        return (self.my_ops.g(self.inputs) - self.targets).norm()

my_objective = myObjective(xrand, yrand)

optimizer = torch.optim.SGD(params=my_objective.my_ops.parameters, lr=0.01)

# Optimization loop
for i in range(1000):
    # Zero the gradients
    optimizer.zero_grad()

    # Compute the value of the objective function
    y = my_objective.forward()

    # Compute the gradients
    y.backward()

    # Update the parameters
    optimizer.step()

I validated that the custom class loading mechanism through ‘torch.ops.load_library’ works as intended. The .forward() of my_objective works as well.
But I am missing something regarding the way parameters are managed. This code crashes at

optimizer = torch.optim.SGD(params=my_objective.my_ops.parameters, lr=0.01)

with the error

  File "main.py", line 61, in <module>
    optimizer = torch.optim.SGD(params=my_objective.my_ops.parameters, lr=0.01)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/optim/sgd.py", line 27, in __init__
    super().__init__(params, defaults)
  File ".../lib/python3.11/site-packages/torch/optim/optimizer.py", line 185, in __init__
    param_groups = list(params)
                   ^^^^^^^^^^^^
TypeError: 'torch._C.ScriptMethod' object is not iterable

Is that kind of design viable, and how can I fix this issue? Thanks!

Try optimizer = torch.optim.SGD(params=my_objective.my_ops.parameters(), lr=0.01) (note the call to .parameters()).

Thank you @ptrblck for your fast reply, you’re right:

optimizer = torch.optim.SGD(params=[my_objective.my_ops.parameters()], lr=0.01)

fixed the issue and the optimization recovers Wsol! Note that since my parameters() method returns a single tensor rather than an iterable of parameters, it also has to be wrapped in a list for the optimizer.
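
For completeness, with the list wrapping, the working part of the script looks roughly like this (the closeness check at the end is just illustrative):

optimizer = torch.optim.SGD(params=[my_objective.my_ops.parameters()], lr=0.01)

for i in range(1000):
    optimizer.zero_grad()
    loss = my_objective()
    loss.backward()
    optimizer.step()

# The exposed parameter should end up close to Wsol, as observed above.
print(torch.allclose(my_objective.my_ops.parameters(), Wsol, atol=1e-3))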

When I started working on this Python / C++ module interaction I wasn’t sure about the overall design. What I understand is that my C++ class has to inherit from torch::jit::CustomClassHolder to be loadable in Python. That’s fine, but it lacks the handy .parameters() mechanism of nn::Module.
So I then added an nn::Module member to the class and a parameters() method that forwards to the nn::Module’s parameters.
Isn’t there a more direct way to expose an nn::Module, for example by using multiple inheritance in MyTorchClass (I know there is torch::jit::Module), or anything else?