Exposing named parameters and parameters of a custom module (LibTorch)

I have the following implementation:

#include <torch/extension.h>
#include <torch/torch.h>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

class OptimizedACPNImpl : public torch::nn::Module {
private:
    int64_t in_features_;
    int64_t out_features_;
    int64_t rank_;
    int64_t query_dim_;
    bool has_bias_;

    torch::Tensor U_, V_, W_q_, W_k_, bias_, attention_scale_;
    
    // Cache for the key projection
    torch::Tensor cached_key_proj_;
    bool cache_valid_ = false;
    torch::Device last_device_ = torch::Device(torch::kCPU);
    
    // Safety flag for device checking
    bool check_device_ = true;
    bool first_forward_done_ = false;

public:
    OptimizedACPNImpl(int64_t in_features, int64_t out_features, 
                      int64_t rank = -1, int64_t query_dim = 16, bool use_bias = true)
        : in_features_(in_features), out_features_(out_features),
          query_dim_(query_dim), has_bias_(use_bias) {
        
        if (rank < 0) {
            rank_ = static_cast<int64_t>(0.35 * std::min(in_features, out_features));
            rank_ = std::max(static_cast<int64_t>(1), rank_);
        } else {
            rank_ = rank;
        }

        // Register parameters
        U_ = register_parameter("U", torch::zeros({rank_, in_features_}));
        V_ = register_parameter("V", torch::zeros({out_features_, rank_}));
        W_q_ = register_parameter("W_q", torch::zeros({query_dim_, in_features_}));
        W_k_ = register_parameter("W_k", torch::zeros({rank_, query_dim_}));
        attention_scale_ = register_parameter("attention_scale", torch::ones(1));
        
        if (has_bias_) {
            bias_ = register_parameter("bias", torch::zeros(out_features_));
        }

        reset_parameters();
    }

    void reset_parameters() {
        double std_val = 0.02;
        torch::nn::init::normal_(U_, 0.0, std_val);
        torch::nn::init::normal_(V_, 0.0, std_val);
        torch::nn::init::normal_(W_q_, 0.0, std_val);
        torch::nn::init::normal_(W_k_, 0.0, std_val);
        torch::nn::init::constant_(attention_scale_, 0.5);
        if (has_bias_) {
            torch::nn::init::zeros_(bias_);
        }
        
        // Invalidate cache
        cache_valid_ = false;
        first_forward_done_ = false;
    }

    // Method to toggle device safety checks
    void set_device_checking(bool enabled) {
        check_device_ = enabled;
    }

    // Method to check if device checking is enabled
    bool get_device_checking() const {
        return check_device_;
    }

    torch::Tensor compute_key_projection(const torch::Device& device) {
        // Skip cache in training mode to avoid autograd issues
        if (is_training() || !cache_valid_) {
            // Compute the projection
            double scale = 1.0 / std::sqrt(static_cast<double>(query_dim_));
            cached_key_proj_ = torch::mm(W_k_, W_q_) * scale;
            cache_valid_ = !is_training(); // Only cache in eval mode
        }
        return cached_key_proj_;
    }

    std::vector<torch::Tensor> debug_parameters(bool recurse = true) {
        std::cout << "Debug parameters called, recurse=" << recurse << std::endl;
        auto params = parameters(recurse);
        std::cout << "Number of parameters: " << params.size() << std::endl;
        for (size_t i = 0; i < params.size(); ++i) {
            std::cout << "  Param " << i << " device: " << params[i].device() << std::endl;
        }
        return params;
    }

    torch::Tensor forward(const torch::Tensor& x) {
        auto x_cont = x.contiguous();
        
        // Check if this is the first forward pass - essential for initialization
        if (!first_forward_done_) {
            // On first forward, ensure all parameters are on the input device
            U_ = U_.to(x.device());
            V_ = V_.to(x.device());
            W_q_ = W_q_.to(x.device());
            W_k_ = W_k_.to(x.device());
            attention_scale_ = attention_scale_.to(x.device());
            if (has_bias_) {
                bias_ = bias_.to(x.device());
            }
            first_forward_done_ = true;
            last_device_ = x.device();
        }
        
        // Get key projection
        auto key_proj = compute_key_projection(x.device());
        
        // Main computation path - assuming all tensors are on the correct device
        auto main = torch::nn::functional::linear(x_cont, U_);
        auto attn_logits = torch::nn::functional::linear(x_cont, key_proj);
        auto attn_weights = torch::softmax(attn_logits, -1);
        
        
        // Apply attention 
        auto modulated = torch::addcmul(main, main, attention_scale_ * attn_weights);
        
        // Final projection
        return has_bias_ ? 
            torch::nn::functional::linear(modulated, V_, bias_) :
            torch::nn::functional::linear(modulated, V_);
    }

    void train(bool on = true) override {
        torch::nn::Module::train(on);
        if (on) {
            // Invalidate cache when transitioning to training mode
            cache_valid_ = false;
            
            // Re-enable device checking in training mode for safety
            check_device_ = true;
        }
    }

    void eval() {
        torch::nn::Module::eval();
    }

    // Debug helper: explicitly move all member tensors to the given device
    void move_to_device(torch::Device device) {
        std::cout << "move_to_device called for: " << device << std::endl;
        
        U_ = U_.to(device);
        V_ = V_.to(device);
        W_q_ = W_q_.to(device);
        W_k_ = W_k_.to(device);
        attention_scale_ = attention_scale_.to(device);
        if (has_bias_) {
            bias_ = bias_.to(device);
        }
        
        std::cout << "After explicit move: U_ device: " << U_.device() << std::endl;
    }

    void to(torch::Device device, bool non_blocking = false) override {
        // First, call base class implementation (important to do this FIRST)
        torch::nn::Module::to(device, non_blocking);
        
        // Then explicitly move the member variables
        U_ = U_.to(device, non_blocking);
        V_ = V_.to(device, non_blocking);
        W_q_ = W_q_.to(device, non_blocking);
        W_k_ = W_k_.to(device, non_blocking);
        attention_scale_ = attention_scale_.to(device, non_blocking);
        if (has_bias_) {
            bias_ = bias_.to(device, non_blocking);
        }
        
        // Update cache and device tracking
        if (cached_key_proj_.defined()) {
            cached_key_proj_ = cached_key_proj_.to(device, non_blocking);
        }
        cache_valid_ = false;
        first_forward_done_ = false;
        last_device_ = device;
    }

    void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false) override {
        std::cout << "OptimizedACPNImpl::to(device, dtype) called" << std::endl;
        
        // Move member tensors to device and dtype
        U_ = U_.to(device, dtype, non_blocking);
        V_ = V_.to(device, dtype, non_blocking);
        W_q_ = W_q_.to(device, dtype, non_blocking);
        W_k_ = W_k_.to(device, dtype, non_blocking);
        attention_scale_ = attention_scale_.to(device, dtype, non_blocking);
        if (has_bias_) {
            bias_ = bias_.to(device, dtype, non_blocking);
        }
        
        // Call base class implementation
        torch::nn::Module::to(device, dtype, non_blocking);
        
        // Reset cache and states
        cache_valid_ = false;
        first_forward_done_ = false;
        last_device_ = device;
    }

    // Getters
    int64_t in_features() const { return in_features_; }
    int64_t out_features() const { return out_features_; }
    int64_t rank() const { return rank_; }
    int64_t query_dim() const { return query_dim_; }
};

TORCH_MODULE(OptimizedACPN);
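// (TORCH_MODULE above generates the OptimizedACPN holder type, a thin wrapper
// around std::shared_ptr<OptimizedACPNImpl>, so the module can be passed by value.)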

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    py::class_<OptimizedACPNImpl, std::shared_ptr<OptimizedACPNImpl>, torch::nn::Module>(m, "_OptimizedACPNImpl")
        .def(py::init<int64_t, int64_t, int64_t, int64_t, bool>())
        .def("forward", &OptimizedACPNImpl::forward)
        .def("reset_parameters", &OptimizedACPNImpl::reset_parameters)
        .def("train", &OptimizedACPNImpl::train, py::arg("on") = true)
        .def("eval", &OptimizedACPNImpl::eval)
        .def("set_device_checking", &OptimizedACPNImpl::set_device_checking)
        .def("get_device_checking", &OptimizedACPNImpl::get_device_checking)
        .def_property_readonly("in_features", &OptimizedACPNImpl::in_features)
        .def_property_readonly("out_features", &OptimizedACPNImpl::out_features)
        .def_property_readonly("rank", &OptimizedACPNImpl::rank)
        .def_property_readonly("query_dim", &OptimizedACPNImpl::query_dim);

    py::class_<OptimizedACPN, std::shared_ptr<OptimizedACPN>>(m, "OptimizedACPN", py::module_local())
        .def(py::init<int64_t, int64_t, int64_t, int64_t, bool>(),
             py::arg("in_features"),
             py::arg("out_features"),
             py::arg("rank") = -1,
             py::arg("query_dim") = 16,
             py::arg("use_bias") = true)
        .def("forward", [](OptimizedACPN& m, const torch::Tensor& x) {
            return m->forward(x);
        })
        .def("__call__", [](OptimizedACPN& m, const torch::Tensor& x) {
            return m->forward(x);
        })
        .def("reset_parameters", [](OptimizedACPN& m) {
            m->reset_parameters();
        })
        .def("train", [](OptimizedACPN& m, bool on = true) {
            m->train(on);
            return m;
        }, py::arg("on") = true)
        .def("eval", [](OptimizedACPN& m) {
            m->eval();
            return m;
        })
        .def("set_device_checking", [](OptimizedACPN& m, bool enabled) {
            m->set_device_checking(enabled);
            return m;
        })
        .def("get_device_checking", [](OptimizedACPN& m) {
            return m->get_device_checking();
        })
        .def("parameters", [](OptimizedACPN& m, bool recurse = true) {
            return py::make_iterator(m->parameters(recurse));
        }, py::arg("recurse") = true, py::keep_alive<0, 1>())
        
        .def("named_parameters", [](OptimizedACPN& m) {
            std::vector<std::pair<std::string, torch::Tensor>> named_params;
            for (auto& p : m->named_parameters()) {
                named_params.push_back(std::make_pair(p.key(), p.value()));
            }
            return named_params;
        })
        .def("move_to_device", [](OptimizedACPN& m, torch::Device device) {
            m->move_to_device(device);
            return m;
        })
        .def("debug_parameters", [](OptimizedACPN& m) {
            auto params = m->parameters();
            for (size_t i = 0; i < params.size(); ++i) {
                std::cout << "  Parameter " << i << " device: " << params[i].device() << std::endl;
            }
            return py::make_iterator(params.begin(), params.end());
        }, py::keep_alive<0, 1>())
        
        .def("get_param_devices", [](OptimizedACPN& m) {
            std::vector<std::string> devices;
            for (auto& p : m->parameters()) {
                devices.push_back(c10::str(p.device()));
            }
            return devices;
        })
        .def("to", [](OptimizedACPN& m, torch::Device device) {
            std::cout << "Python binding: OptimizedACPN.to(device) called" << std::endl;
            m->to(device);
            return m;
        })
        .def("to", [](OptimizedACPN& m, py::object device) {
            if (py::isinstance<py::str>(device)) {
                // Handle string device
                m->to(torch::Device(py::cast<std::string>(device)));
            } else if (py::hasattr(device, "type")) {
                // Handle torch.device
                std::string type = py::str(device.attr("type"));
                int index = 0;
                if (!device.attr("index").is_none()) {
                    index = py::cast<int>(device.attr("index"));
                }
                
                if (type == "cuda") {
                    m->to(torch::Device(torch::kCUDA, index));
                } else {
                    m->to(torch::Device(torch::kCPU));
                }
            } else if (py::isinstance<py::int_>(device)) {
                // Handle CUDA device index
                m->to(torch::Device(torch::kCUDA, py::cast<int>(device)));
            } else if (py::isinstance<torch::Tensor>(device)) {
                // Handle tensor device
                torch::Tensor tensor = py::cast<torch::Tensor>(device);
                m->to(tensor.device());
            } else {
                throw std::runtime_error("Unsupported device type");
            }
            return m;
        })
        .def_property_readonly("in_features", [](OptimizedACPN& m) {
            return m->in_features();
        })
        .def_property_readonly("out_features", [](OptimizedACPN& m) {
            return m->out_features();
        })
        .def_property_readonly("rank", [](OptimizedACPN& m) {
            return m->rank();
        })
        .def_property_readonly("query_dim", [](OptimizedACPN& m) {
            return m->query_dim();
        })
        .def("__repr__", [](const OptimizedACPN& m) {
            std::ostringstream ss;
            ss << "OptimizedACPN(in_features=" << m->in_features() 
               << ", out_features=" << m->out_features()
               << ", rank=" << m->rank()
               << ", query_dim=" << m->query_dim() << ")";
            return ss.str();
        })
        .def("_get_name", [](const OptimizedACPN&) {
            return "OptimizedACPN";
        });
}

I have a script that tests the module from Python and prints a debug log:

ACPN Parameter Device Debug Script
PyTorch version: 2.6.0+cu126
CUDA available: True
CUDA device: NVIDIA GeForce RTX 3090



===== ACPN Parameter Device Handling Test =====

1. After initialization (should be on CPU):
Model parameters:
  U: shape=torch.Size([22, 64]), device=cpu, requires_grad=True
  V: shape=torch.Size([128, 22]), device=cpu, requires_grad=True
  W_q: shape=torch.Size([16, 64]), device=cpu, requires_grad=True
  W_k: shape=torch.Size([22, 16]), device=cpu, requires_grad=True
  attention_scale: shape=torch.Size([1]), device=cpu, requires_grad=True
  bias: shape=torch.Size([128]), device=cpu, requires_grad=True
  ACPN internal devices: ['cpu', 'cpu', 'cpu', 'cpu', 'cpu', 'cpu']

2. Try directly accessing attributes via Python:
  Direct U_ access: Not accessible
  Direct V_ access: Not accessible

3. Testing .to() method (string):
  After .to('cpu'): Model parameters:
  U: shape=torch.Size([22, 64]), device=cpu, requires_grad=True
  V: shape=torch.Size([128, 22]), device=cpu, requires_grad=True
  W_q: shape=torch.Size([16, 64]), device=cpu, requires_grad=True
  W_k: shape=torch.Size([22, 16]), device=cpu, requires_grad=True
  attention_scale: shape=torch.Size([1]), device=cpu, requires_grad=True
  bias: shape=torch.Size([128]), device=cpu, requires_grad=True
  ACPN internal devices: ['cpu', 'cpu', 'cpu', 'cpu', 'cpu', 'cpu']

4. Moving to CUDA with string:
  After .to('cuda'): Model parameters:
  U: shape=torch.Size([22, 64]), device=cuda:0, requires_grad=True
  V: shape=torch.Size([128, 22]), device=cuda:0, requires_grad=True
  W_q: shape=torch.Size([16, 64]), device=cuda:0, requires_grad=True
  W_k: shape=torch.Size([22, 16]), device=cuda:0, requires_grad=True
  attention_scale: shape=torch.Size([1]), device=cuda:0, requires_grad=True
  bias: shape=torch.Size([128]), device=cuda:0, requires_grad=True
  ACPN internal devices: ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0']

5. Moving back to CPU:
  After .to('cpu'): Model parameters:
  U: shape=torch.Size([22, 64]), device=cpu, requires_grad=True
  V: shape=torch.Size([128, 22]), device=cpu, requires_grad=True
  W_q: shape=torch.Size([16, 64]), device=cpu, requires_grad=True
  W_k: shape=torch.Size([22, 16]), device=cpu, requires_grad=True
  attention_scale: shape=torch.Size([1]), device=cpu, requires_grad=True
  bias: shape=torch.Size([128]), device=cpu, requires_grad=True
  ACPN internal devices: ['cpu', 'cpu', 'cpu', 'cpu', 'cpu', 'cpu']

6. Moving to CUDA with torch.device:
Python binding: OptimizedACPN.to(device) called
  After .to(device): Model parameters:
  U: shape=torch.Size([22, 64]), device=cuda:0, requires_grad=True
  V: shape=torch.Size([128, 22]), device=cuda:0, requires_grad=True
  W_q: shape=torch.Size([16, 64]), device=cuda:0, requires_grad=True
  W_k: shape=torch.Size([22, 16]), device=cuda:0, requires_grad=True
  attention_scale: shape=torch.Size([1]), device=cuda:0, requires_grad=True
  bias: shape=torch.Size([128]), device=cuda:0, requires_grad=True
  ACPN internal devices: ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0']

7. Using custom move_to_device method:
move_to_device called for: cpu
After explicit move: U_ device: cpu
  After move_to_device(cpu): Model parameters:
  U: shape=torch.Size([22, 64]), device=cuda:0, requires_grad=True
  V: shape=torch.Size([128, 22]), device=cuda:0, requires_grad=True
  W_q: shape=torch.Size([16, 64]), device=cuda:0, requires_grad=True
  W_k: shape=torch.Size([22, 16]), device=cuda:0, requires_grad=True
  attention_scale: shape=torch.Size([1]), device=cuda:0, requires_grad=True
  bias: shape=torch.Size([128]), device=cuda:0, requires_grad=True
  ACPN internal devices: ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0']
move_to_device called for: cuda
After explicit move: U_ device: cuda:0
  After move_to_device(cuda): Model parameters:
  U: shape=torch.Size([22, 64]), device=cuda:0, requires_grad=True
  V: shape=torch.Size([128, 22]), device=cuda:0, requires_grad=True
  W_q: shape=torch.Size([16, 64]), device=cuda:0, requires_grad=True
  W_k: shape=torch.Size([22, 16]), device=cuda:0, requires_grad=True
  attention_scale: shape=torch.Size([1]), device=cuda:0, requires_grad=True
  bias: shape=torch.Size([128]), device=cuda:0, requires_grad=True
  ACPN internal devices: ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0']

8. Testing forward pass with tensor on CUDA:
  Forward pass successful: output shape=torch.Size([8, 128]), device=cuda:0

===== Comparing with PyTorch Linear =====

1. After initialization:
ACPN parameters:
  Model parameters:
  U: shape=torch.Size([22, 64]), device=cpu, requires_grad=True
  V: shape=torch.Size([128, 22]), device=cpu, requires_grad=True
  W_q: shape=torch.Size([16, 64]), device=cpu, requires_grad=True
  W_k: shape=torch.Size([22, 16]), device=cpu, requires_grad=True
  attention_scale: shape=torch.Size([1]), device=cpu, requires_grad=True
  bias: shape=torch.Size([128]), device=cpu, requires_grad=True
  ACPN internal devices: ['cpu', 'cpu', 'cpu', 'cpu', 'cpu', 'cpu']

PyTorch Linear parameters:
  Model parameters:
  weight: shape=torch.Size([128, 64]), device=cpu, requires_grad=True
  bias: shape=torch.Size([128]), device=cpu, requires_grad=True

2. After moving to CUDA:
ACPN parameters:
  Model parameters:
  U: shape=torch.Size([22, 64]), device=cuda:0, requires_grad=True
  V: shape=torch.Size([128, 22]), device=cuda:0, requires_grad=True
  W_q: shape=torch.Size([16, 64]), device=cuda:0, requires_grad=True
  W_k: shape=torch.Size([22, 16]), device=cuda:0, requires_grad=True
  attention_scale: shape=torch.Size([1]), device=cuda:0, requires_grad=True
  bias: shape=torch.Size([128]), device=cuda:0, requires_grad=True
  ACPN internal devices: ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0']

PyTorch Linear parameters:
  Model parameters:
  weight: shape=torch.Size([128, 64]), device=cuda:0, requires_grad=True
  bias: shape=torch.Size([128]), device=cuda:0, requires_grad=True

3. Forward pass on CUDA:
  ACPN output: shape=torch.Size([8, 128]), device=cuda:0
  Linear output: shape=torch.Size([8, 128]), device=cuda:0

===== Parameter Registration and Iteration Test =====

1. Count parameters:
  Number of parameters from direct call: 6

2. Parameter iteration check:
  Param 0: shape=torch.Size([22, 64]), device=cpu
  Param 1: shape=torch.Size([128, 22]), device=cpu
  Param 2: shape=torch.Size([16, 64]), device=cpu
  Param 3: shape=torch.Size([22, 16]), device=cpu
  Param 4: shape=torch.Size([1]), device=cpu
  Param 5: shape=torch.Size([128]), device=cpu

3. Named parameters check:
  Number of named parameters: 6
  U: shape=torch.Size([22, 64]), device=cpu
  V: shape=torch.Size([128, 22]), device=cpu
  W_q: shape=torch.Size([16, 64]), device=cpu
  W_k: shape=torch.Size([22, 16]), device=cpu
  attention_scale: shape=torch.Size([1]), device=cpu
  bias: shape=torch.Size([128]), device=cpu

4. Parameter propagation in container:
Traceback (most recent call last):
  File "c:\AIProjects\ACPNC\acpn_model_test.py", line 190, in <module>
    test_parameter_registration_and_iteration()
  File "c:\AIProjects\ACPNC\acpn_model_test.py", line 152, in test_parameter_registration_and_iteration
    container = nn.Sequential(model)
  File "C:\Users\OliverWH\anaconda3\envs\aiwork\lib\site-packages\torch\nn\modules\container.py", line 127, in __init__
    self.add_module(str(idx), module)
  File "C:\Users\OliverWH\anaconda3\envs\aiwork\lib\site-packages\torch\nn\modules\module.py", line 645, in add_module
    raise TypeError(f"{torch.typename(module)} is not a Module subclass")
TypeError: optimized_acpn.OptimizedACPN is not a Module subclass

I can't seem to get the module to register properly as a PyTorch module. I had to implement the .to() method manually, and the knock-on effect is that I can't use the layer in nn.Sequential, and when it is part of a larger module its parameters aren't registered for training.
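To make the failure concrete, this is the usage I'm ultimately after (nothing new here, it just condenses the two errors from the logs in this post; it assumes the usual import torch.nn as nn / import torch.optim as optim):

layer = OptimizedACPN(64, 128)

# Fails with: TypeError: optimized_acpn.OptimizedACPN is not a Module subclass
container = nn.Sequential(layer, nn.ReLU())

# A parent nn.Module holding this layer (the ACPNAutoencoder below) reports no
# parameters, so building an optimizer from them fails with:
# ValueError: optimizer got an empty parameter list
optimizer = optim.SGD(ACPNAutoencoder().parameters(), lr=0.01)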

class ACPNAutoencoder(nn.Module):
    def __init__(self):
        super(ACPNAutoencoder, self).__init__()
        # Use OptimizedACPN directly without any wrapper
        self.encoder = OptimizedACPN(784, 128)
        self.decoder = OptimizedACPN(128, 784)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        # Flatten the 28x28 image to a vector of 784
        x = x.view(x.size(0), -1)
        x = self.encoder(x)
        x = self.relu(x)
        x = self.decoder(x)
        x = self.sigmoid(x)
        # Reshape back to image
        x = x.view(-1, 1, 28, 28)
        return x

model_acpn = ACPNAutoencoder().to(device)

print(f"ACPN model parameters count: {sum(1 for _ in model_acpn.parameters())}")
print("ACPN model parameters:")
for name, param in model_acpn.named_parameters():
    print(f"  {name}: {param.shape}")

optimizer_acpn = optim.SGD(acpn_params, lr=lr)
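# (acpn_params is collected from model_acpn.parameters() elsewhere in the script,
#  which, as the log below shows, comes back empty)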

Logs:

Using device: cuda
ACPN model parameters count: 0
ACPN model parameters:
Number of parameters for ACPN optimizer: 0
Number of parameters for Linear optimizer: 4
Traceback (most recent call last):
  File "c:\AIProjects\ACPNC\acpn_autoencoder.py", line 88, in <module>
    optimizer_acpn = optim.SGD(acpn_params, lr=lr)  # Try SGD instead of Adam
  File "C:\Users\OliverWH\anaconda3\envs\aiwork\lib\site-packages\torch\optim\sgd.py", line 63, in __init__
    super().__init__(params, defaults)
  File "C:\Users\OliverWH\anaconda3\envs\aiwork\lib\site-packages\torch\optim\optimizer.py", line 372, in __init__
    raise ValueError("optimizer got an empty parameter list")
ValueError: optimizer got an empty parameter list

Can anyone help me understand the correct approach, or point me to a working implementation on GitHub that uses extensions the way I am?