I have the following implementation:
#include <torch/extension.h>
#include <torch/torch.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>
// Low-rank linear layer with attention-based channel modulation (ACPN).
//
// forward(x) for x of shape (batch, in_features):
//   main   = x @ U^T                                   -> (batch, rank)
//   logits = x @ (W_k @ W_q / sqrt(query_dim))^T       -> (batch, rank)
//   out    = (main + main * attention_scale * softmax(logits, -1)) @ V^T [+ bias]
//
// IMPORTANT invariant: the tensor members below share storage (TensorImpl)
// with the module's registered parameters — register_parameter() returns the
// tensor stored in the parameter list, and the member is a copy sharing that
// impl. They must NEVER be reassigned with `member = member.to(...)`: that
// creates a fresh tensor and silently detaches the member from the registered
// parameter, so the optimizer trains one tensor while forward() reads another
// (this is the desync visible in the debug log's move_to_device step). All
// device/dtype moves therefore go through torch::nn::Module::to(), which
// moves registered parameters in place via set_data() and keeps every alias
// in sync.
class OptimizedACPNImpl : public torch::nn::Module {
private:
  int64_t in_features_;
  int64_t out_features_;
  int64_t rank_;       // low-rank bottleneck width
  int64_t query_dim_;  // inner dimension of the W_k @ W_q factorization
  bool has_bias_;
  torch::Tensor U_, V_, W_q_, W_k_, bias_, attention_scale_;
  // Eval-only cache for the key projection W_k @ W_q / sqrt(query_dim).
  torch::Tensor cached_key_proj_;
  bool cache_valid_ = false;
  torch::Device last_device_ = torch::Device(torch::kCPU);
  // When true, forward() moves the whole module to the input's device on the
  // first call (safety net for callers that forgot .to()).
  bool check_device_ = true;
  bool first_forward_done_ = false;

public:
  // rank < 0 auto-selects ~35% of min(in_features, out_features), at least 1.
  OptimizedACPNImpl(int64_t in_features, int64_t out_features,
                    int64_t rank = -1, int64_t query_dim = 16,
                    bool use_bias = true)
      : in_features_(in_features),
        out_features_(out_features),
        query_dim_(query_dim),
        has_bias_(use_bias) {
    if (rank < 0) {
      rank_ = static_cast<int64_t>(0.35 * std::min(in_features, out_features));
      rank_ = std::max(static_cast<int64_t>(1), rank_);
    } else {
      rank_ = rank;
    }
    // Register parameters once; the members alias the registered tensors
    // from here on (see class comment).
    U_ = register_parameter("U", torch::zeros({rank_, in_features_}));
    V_ = register_parameter("V", torch::zeros({out_features_, rank_}));
    W_q_ = register_parameter("W_q", torch::zeros({query_dim_, in_features_}));
    W_k_ = register_parameter("W_k", torch::zeros({rank_, query_dim_}));
    attention_scale_ = register_parameter("attention_scale", torch::ones(1));
    if (has_bias_) {
      bias_ = register_parameter("bias", torch::zeros({out_features_}));
    }
    reset_parameters();
  }

  // Re-initializes all parameters IN PLACE (init::* are in-place ops, so the
  // member/registered-parameter aliasing is preserved) and drops the cache.
  void reset_parameters() {
    const double std_val = 0.02;
    torch::nn::init::normal_(U_, 0.0, std_val);
    torch::nn::init::normal_(V_, 0.0, std_val);
    torch::nn::init::normal_(W_q_, 0.0, std_val);
    torch::nn::init::normal_(W_k_, 0.0, std_val);
    torch::nn::init::constant_(attention_scale_, 0.5);
    if (has_bias_) {
      torch::nn::init::zeros_(bias_);
    }
    cache_valid_ = false;
    first_forward_done_ = false;
  }

  // Enable/disable the first-forward device safety net.
  void set_device_checking(bool enabled) { check_device_ = enabled; }
  bool get_device_checking() const { return check_device_; }

  // Returns W_k @ W_q / sqrt(query_dim). Recomputed every call in training
  // mode (caching there would reuse a stale autograd graph); cached in eval
  // mode, and rebuilt if the cache is missing or on the wrong device.
  torch::Tensor compute_key_projection(const torch::Device& device) {
    const bool cache_usable = cache_valid_ && cached_key_proj_.defined() &&
                              cached_key_proj_.device() == device;
    if (is_training() || !cache_usable) {
      const double scale = 1.0 / std::sqrt(static_cast<double>(query_dim_));
      cached_key_proj_ = torch::mm(W_k_, W_q_) * scale;
      cache_valid_ = !is_training();
    }
    return cached_key_proj_;
  }

  // Prints device placement of every registered parameter (debug aid).
  std::vector<torch::Tensor> debug_parameters(bool recurse = true) {
    std::cout << "Debug parameters called, recurse=" << recurse << std::endl;
    auto params = parameters(recurse);
    std::cout << "Number of parameters: " << params.size() << std::endl;
    for (size_t i = 0; i < params.size(); ++i) {
      std::cout << " Param " << i << " device: " << params[i].device() << std::endl;
    }
    return params;
  }

  // Forward pass; x is expected to be (batch, in_features).
  torch::Tensor forward(const torch::Tensor& x) {
    auto x_cont = x.contiguous();
    if (!first_forward_done_) {
      // Safety net: on the very first call, if enabled, pull the module onto
      // the input's device. FIX: this previously reassigned each member
      // (`U_ = U_.to(...)`), detaching them from the registered parameters.
      // Going through to() moves the registered tensors in place instead.
      if (check_device_ && U_.defined() && U_.device() != x.device()) {
        to(x.device());
      }
      first_forward_done_ = true;
      last_device_ = x.device();
    }
    auto key_proj = compute_key_projection(x.device());
    auto main = torch::nn::functional::linear(x_cont, U_);            // (B, rank)
    auto attn_logits = torch::nn::functional::linear(x_cont, key_proj); // (B, rank)
    auto attn_weights = torch::softmax(attn_logits, -1);
    // modulated = main * (1 + attention_scale * attn_weights)
    auto modulated = torch::addcmul(main, main, attention_scale_ * attn_weights);
    return has_bias_
        ? torch::nn::functional::linear(modulated, V_, bias_)
        : torch::nn::functional::linear(modulated, V_);
  }

  // Entering training mode invalidates the eval cache and re-arms the
  // device safety net.
  void train(bool on = true) override {
    torch::nn::Module::train(on);
    if (on) {
      cache_valid_ = false;
      check_device_ = true;
    }
  }

  // Delegates to the base (which calls the virtual train(false) above).
  void eval() { torch::nn::Module::eval(); }

  // Explicit move helper exposed to Python. FIX: previously reassigned the
  // tensor members directly, desyncing them from the registered parameters
  // (the debug log shows U_ on cpu while parameters() stayed on cuda).
  void move_to_device(torch::Device device) {
    std::cout << "move_to_device called for: " << device << std::endl;
    to(device);
    std::cout << "After explicit move: U_ device: " << U_.device() << std::endl;
  }

  // Re-expose the base overload set (e.g. to(Dtype)): the declarations below
  // would otherwise hide it for C++ callers.
  using torch::nn::Module::to;

  void to(torch::Device device, bool non_blocking = false) override {
    // Base implementation moves every registered parameter/buffer in place
    // (set_data), so the aliasing tensor members follow automatically — no
    // manual per-member reassignment (which would break the aliasing).
    torch::nn::Module::to(device, non_blocking);
    if (cached_key_proj_.defined()) {
      cached_key_proj_ = cached_key_proj_.to(device, non_blocking);
    }
    cache_valid_ = false;
    first_forward_done_ = false;
    last_device_ = device;
  }

  void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false) override {
    // FIX: previously converted each member with `U_ = U_.to(device, dtype)`
    // BEFORE calling the base, leaving the members pointing at different
    // tensors than the registered parameters. The base call alone is correct.
    torch::nn::Module::to(device, dtype, non_blocking);
    // Cache holds the old dtype — drop it entirely and rebuild lazily.
    cached_key_proj_ = torch::Tensor();
    cache_valid_ = false;
    first_forward_done_ = false;
    last_device_ = device;
  }

  // Configuration getters.
  int64_t in_features() const { return in_features_; }
  int64_t out_features() const { return out_features_; }
  int64_t rank() const { return rank_; }
  int64_t query_dim() const { return query_dim_; }
};
// Generates the `OptimizedACPN` holder type: a shared_ptr-like wrapper
// around OptimizedACPNImpl, following the torch::nn Impl/holder convention
// (access the implementation with operator->).
TORCH_MODULE(OptimizedACPN);
// Python bindings.
//
// NOTE(review): a class bound with py::class_ is NOT a subclass of Python's
// torch.nn.Module — this is why nn.Sequential raises "is not a Module
// subclass" and why a parent nn.Module cannot see these parameters. The
// supported patterns are (a) torch::python::bind_module<OptimizedACPNImpl>
// from <torch/python.h>, or (b) a thin Python nn.Module wrapper that
// re-registers the tensors returned by named_parameters(). Keeping the
// existing bindings here for backward compatibility.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Raw Impl binding (underscore-prefixed: internal use).
  py::class_<OptimizedACPNImpl, std::shared_ptr<OptimizedACPNImpl>, torch::nn::Module>(m, "_OptimizedACPNImpl")
      .def(py::init<int64_t, int64_t, int64_t, int64_t, bool>())
      .def("forward", &OptimizedACPNImpl::forward)
      .def("reset_parameters", &OptimizedACPNImpl::reset_parameters)
      .def("train", &OptimizedACPNImpl::train, py::arg("on") = true)
      .def("eval", &OptimizedACPNImpl::eval)
      .def("set_device_checking", &OptimizedACPNImpl::set_device_checking)
      .def("get_device_checking", &OptimizedACPNImpl::get_device_checking)
      .def_property_readonly("in_features", &OptimizedACPNImpl::in_features)
      .def_property_readonly("out_features", &OptimizedACPNImpl::out_features)
      .def_property_readonly("rank", &OptimizedACPNImpl::rank)
      .def_property_readonly("query_dim", &OptimizedACPNImpl::query_dim);

  // Holder binding (what user code instantiates).
  py::class_<OptimizedACPN, std::shared_ptr<OptimizedACPN>>(m, "OptimizedACPN", py::module_local())
      .def(py::init<int64_t, int64_t, int64_t, int64_t, bool>(),
           py::arg("in_features"),
           py::arg("out_features"),
           py::arg("rank") = -1,
           py::arg("query_dim") = 16,
           py::arg("use_bias") = true)
      .def("forward", [](OptimizedACPN& m, const torch::Tensor& x) {
        return m->forward(x);
      })
      .def("__call__", [](OptimizedACPN& m, const torch::Tensor& x) {
        return m->forward(x);
      })
      .def("reset_parameters", [](OptimizedACPN& m) {
        m->reset_parameters();
      })
      .def("train", [](OptimizedACPN& m, bool on) {
        m->train(on);
        return m;
      }, py::arg("on") = true)
      .def("eval", [](OptimizedACPN& m) {
        m->eval();
        return m;
      })
      .def("set_device_checking", [](OptimizedACPN& m, bool enabled) {
        m->set_device_checking(enabled);
        return m;
      })
      .def("get_device_checking", [](OptimizedACPN& m) {
        return m->get_device_checking();
      })
      // FIX: the original wrapped a TEMPORARY std::vector in
      // py::make_iterator; the vector died when the lambda returned, leaving
      // a dangling iterator (keep_alive<0,1> pins the module, not the
      // vector). Returning the vector by value converts it to a Python list
      // that owns the tensors.
      .def("parameters", [](OptimizedACPN& m, bool recurse) {
        return m->parameters(recurse);
      }, py::arg("recurse") = true)
      .def("named_parameters", [](OptimizedACPN& m) {
        std::vector<std::pair<std::string, torch::Tensor>> named_params;
        for (auto& p : m->named_parameters()) {
          named_params.emplace_back(p.key(), p.value());
        }
        return named_params;
      })
      .def("move_to_device", [](OptimizedACPN& m, torch::Device device) {
        m->move_to_device(device);
        return m;
      })
      // FIX: same dangling-iterator bug as "parameters" — return the list.
      .def("debug_parameters", [](OptimizedACPN& m) {
        auto params = m->parameters();
        for (size_t i = 0; i < params.size(); ++i) {
          std::cout << " Parameter " << i << " device: " << params[i].device() << std::endl;
        }
        return params;
      })
      .def("get_param_devices", [](OptimizedACPN& m) {
        std::vector<std::string> devices;
        for (auto& p : m->parameters()) {
          devices.push_back(c10::str(p.device()));
        }
        return devices;
      })
      // Fast path: torch's type caster converts str / torch.device directly.
      .def("to", [](OptimizedACPN& m, torch::Device device) {
        std::cout << "Python binding: OptimizedACPN.to(device) called" << std::endl;
        m->to(device);
        return m;
      })
      // Fallback: accept str, torch.device, int (CUDA index), or a tensor
      // (use its device) — mirrors torch.nn.Module.to()'s flexibility.
      .def("to", [](OptimizedACPN& m, py::object device) {
        if (py::isinstance<py::str>(device)) {
          m->to(torch::Device(py::cast<std::string>(device)));
        } else if (py::hasattr(device, "type")) {
          std::string type = py::str(device.attr("type"));
          int index = 0;
          if (!device.attr("index").is_none()) {
            index = py::cast<int>(device.attr("index"));
          }
          if (type == "cuda") {
            m->to(torch::Device(torch::kCUDA, index));
          } else {
            m->to(torch::Device(torch::kCPU));
          }
        } else if (py::isinstance<py::int_>(device)) {
          m->to(torch::Device(torch::kCUDA, py::cast<int>(device)));
        } else if (py::isinstance<torch::Tensor>(device)) {
          torch::Tensor tensor = py::cast<torch::Tensor>(device);
          m->to(tensor.device());
        } else {
          throw std::runtime_error("Unsupported device type");
        }
        return m;
      })
      .def_property_readonly("in_features", [](OptimizedACPN& m) {
        return m->in_features();
      })
      .def_property_readonly("out_features", [](OptimizedACPN& m) {
        return m->out_features();
      })
      .def_property_readonly("rank", [](OptimizedACPN& m) {
        return m->rank();
      })
      .def_property_readonly("query_dim", [](OptimizedACPN& m) {
        return m->query_dim();
      })
      .def("__repr__", [](const OptimizedACPN& m) {
        std::ostringstream ss;
        ss << "OptimizedACPN(in_features=" << m->in_features()
           << ", out_features=" << m->out_features()
           << ", rank=" << m->rank()
           << ", query_dim=" << m->query_dim() << ")";
        return ss.str();
      })
      .def("_get_name", [](const OptimizedACPN&) {
        return "OptimizedACPN";
      });
}
I have a Python script that exercises the module; it produces the following debug log:
ACPN Parameter Device Debug Script
PyTorch version: 2.6.0+cu126
CUDA available: True
CUDA device: NVIDIA GeForce RTX 3090
===== ACPN Parameter Device Handling Test =====
1. After initialization (should be on CPU):
Model parameters:
U: shape=torch.Size([22, 64]), device=cpu, requires_grad=True
V: shape=torch.Size([128, 22]), device=cpu, requires_grad=True
W_q: shape=torch.Size([16, 64]), device=cpu, requires_grad=True
W_k: shape=torch.Size([22, 16]), device=cpu, requires_grad=True
attention_scale: shape=torch.Size([1]), device=cpu, requires_grad=True
bias: shape=torch.Size([128]), device=cpu, requires_grad=True
ACPN internal devices: ['cpu', 'cpu', 'cpu', 'cpu', 'cpu', 'cpu']
2. Try directly accessing attributes via Python:
Direct U_ access: Not accessible
Direct V_ access: Not accessible
3. Testing .to() method (string):
After .to('cpu'): Model parameters:
U: shape=torch.Size([22, 64]), device=cpu, requires_grad=True
V: shape=torch.Size([128, 22]), device=cpu, requires_grad=True
W_q: shape=torch.Size([16, 64]), device=cpu, requires_grad=True
W_k: shape=torch.Size([22, 16]), device=cpu, requires_grad=True
attention_scale: shape=torch.Size([1]), device=cpu, requires_grad=True
bias: shape=torch.Size([128]), device=cpu, requires_grad=True
ACPN internal devices: ['cpu', 'cpu', 'cpu', 'cpu', 'cpu', 'cpu']
4. Moving to CUDA with string:
After .to('cuda'): Model parameters:
U: shape=torch.Size([22, 64]), device=cuda:0, requires_grad=True
V: shape=torch.Size([128, 22]), device=cuda:0, requires_grad=True
W_q: shape=torch.Size([16, 64]), device=cuda:0, requires_grad=True
W_k: shape=torch.Size([22, 16]), device=cuda:0, requires_grad=True
attention_scale: shape=torch.Size([1]), device=cuda:0, requires_grad=True
bias: shape=torch.Size([128]), device=cuda:0, requires_grad=True
ACPN internal devices: ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0']
5. Moving back to CPU:
After .to('cpu'): Model parameters:
U: shape=torch.Size([22, 64]), device=cpu, requires_grad=True
V: shape=torch.Size([128, 22]), device=cpu, requires_grad=True
W_q: shape=torch.Size([16, 64]), device=cpu, requires_grad=True
W_k: shape=torch.Size([22, 16]), device=cpu, requires_grad=True
attention_scale: shape=torch.Size([1]), device=cpu, requires_grad=True
bias: shape=torch.Size([128]), device=cpu, requires_grad=True
ACPN internal devices: ['cpu', 'cpu', 'cpu', 'cpu', 'cpu', 'cpu']
6. Moving to CUDA with torch.device:
Python binding: OptimizedACPN.to(device) called
After .to(device): Model parameters:
U: shape=torch.Size([22, 64]), device=cuda:0, requires_grad=True
V: shape=torch.Size([128, 22]), device=cuda:0, requires_grad=True
W_q: shape=torch.Size([16, 64]), device=cuda:0, requires_grad=True
W_k: shape=torch.Size([22, 16]), device=cuda:0, requires_grad=True
attention_scale: shape=torch.Size([1]), device=cuda:0, requires_grad=True
bias: shape=torch.Size([128]), device=cuda:0, requires_grad=True
ACPN internal devices: ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0']
7. Using custom move_to_device method:
move_to_device called for: cpu
After explicit move: U_ device: cpu
After move_to_device(cpu): Model parameters:
U: shape=torch.Size([22, 64]), device=cuda:0, requires_grad=True
V: shape=torch.Size([128, 22]), device=cuda:0, requires_grad=True
W_q: shape=torch.Size([16, 64]), device=cuda:0, requires_grad=True
W_k: shape=torch.Size([22, 16]), device=cuda:0, requires_grad=True
attention_scale: shape=torch.Size([1]), device=cuda:0, requires_grad=True
bias: shape=torch.Size([128]), device=cuda:0, requires_grad=True
ACPN internal devices: ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0']
move_to_device called for: cuda
After explicit move: U_ device: cuda:0
After move_to_device(cuda): Model parameters:
U: shape=torch.Size([22, 64]), device=cuda:0, requires_grad=True
V: shape=torch.Size([128, 22]), device=cuda:0, requires_grad=True
W_q: shape=torch.Size([16, 64]), device=cuda:0, requires_grad=True
W_k: shape=torch.Size([22, 16]), device=cuda:0, requires_grad=True
attention_scale: shape=torch.Size([1]), device=cuda:0, requires_grad=True
bias: shape=torch.Size([128]), device=cuda:0, requires_grad=True
ACPN internal devices: ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0']
8. Testing forward pass with tensor on CUDA:
Forward pass successful: output shape=torch.Size([8, 128]), device=cuda:0
===== Comparing with PyTorch Linear =====
1. After initialization:
ACPN parameters:
Model parameters:
U: shape=torch.Size([22, 64]), device=cpu, requires_grad=True
V: shape=torch.Size([128, 22]), device=cpu, requires_grad=True
W_q: shape=torch.Size([16, 64]), device=cpu, requires_grad=True
W_k: shape=torch.Size([22, 16]), device=cpu, requires_grad=True
attention_scale: shape=torch.Size([1]), device=cpu, requires_grad=True
bias: shape=torch.Size([128]), device=cpu, requires_grad=True
ACPN internal devices: ['cpu', 'cpu', 'cpu', 'cpu', 'cpu', 'cpu']
PyTorch Linear parameters:
Model parameters:
weight: shape=torch.Size([128, 64]), device=cpu, requires_grad=True
bias: shape=torch.Size([128]), device=cpu, requires_grad=True
2. After moving to CUDA:
ACPN parameters:
Model parameters:
U: shape=torch.Size([22, 64]), device=cuda:0, requires_grad=True
V: shape=torch.Size([128, 22]), device=cuda:0, requires_grad=True
W_q: shape=torch.Size([16, 64]), device=cuda:0, requires_grad=True
W_k: shape=torch.Size([22, 16]), device=cuda:0, requires_grad=True
attention_scale: shape=torch.Size([1]), device=cuda:0, requires_grad=True
bias: shape=torch.Size([128]), device=cuda:0, requires_grad=True
ACPN internal devices: ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0']
PyTorch Linear parameters:
Model parameters:
weight: shape=torch.Size([128, 64]), device=cuda:0, requires_grad=True
bias: shape=torch.Size([128]), device=cuda:0, requires_grad=True
3. Forward pass on CUDA:
ACPN output: shape=torch.Size([8, 128]), device=cuda:0
Linear output: shape=torch.Size([8, 128]), device=cuda:0
===== Parameter Registration and Iteration Test =====
1. Count parameters:
Number of parameters from direct call: 6
2. Parameter iteration check:
Param 0: shape=torch.Size([22, 64]), device=cpu
Param 1: shape=torch.Size([128, 22]), device=cpu
Param 2: shape=torch.Size([16, 64]), device=cpu
Param 3: shape=torch.Size([22, 16]), device=cpu
Param 4: shape=torch.Size([1]), device=cpu
Param 5: shape=torch.Size([128]), device=cpu
3. Named parameters check:
Number of named parameters: 6
U: shape=torch.Size([22, 64]), device=cpu
V: shape=torch.Size([128, 22]), device=cpu
W_q: shape=torch.Size([16, 64]), device=cpu
W_k: shape=torch.Size([22, 16]), device=cpu
attention_scale: shape=torch.Size([1]), device=cpu
bias: shape=torch.Size([128]), device=cpu
4. Parameter propagation in container:
Traceback (most recent call last):
File "c:\AIProjects\ACPNC\acpn_model_test.py", line 190, in <module>
test_parameter_registration_and_iteration()
File "c:\AIProjects\ACPNC\acpn_model_test.py", line 152, in test_parameter_registration_and_iteration
container = nn.Sequential(model)
File "C:\Users\OliverWH\anaconda3\envs\aiwork\lib\site-packages\torch\nn\modules\container.py", line 127, in __init__
self.add_module(str(idx), module)
File "C:\Users\OliverWH\anaconda3\envs\aiwork\lib\site-packages\torch\nn\modules\module.py", line 645, in add_module
raise TypeError(f"{torch.typename(module)} is not a Module subclass")
TypeError: optimized_acpn.OptimizedACPN is not a Module subclass
I can't seem to get the module to register properly as a PyTorch Module. I had to implement the .to() method manually, and the knock-on effect is that I can't use it in nn.Sequential; when I use this layer as part of a larger module, it can't register its parameters for training.
class ACPNAutoencoder(nn.Module):
    """MNIST autoencoder built from two OptimizedACPN layers (784 -> 128 -> 784)."""

    def __init__(self):
        super(ACPNAutoencoder, self).__init__()
        # Use OptimizedACPN directly without any wrapper
        # NOTE(review): OptimizedACPN is a pybind11-bound class, not a Python
        # nn.Module subclass, so these attributes are NOT registered as
        # submodules and their parameters never appear in self.parameters() —
        # this matches the "parameters count: 0" in the log below.
        self.encoder = OptimizedACPN(784, 128)
        self.decoder = OptimizedACPN(128, 784)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Flatten the 28x28 image to a vector of 784
        x = x.view(x.size(0), -1)
        x = self.encoder(x)
        x = self.relu(x)
        x = self.decoder(x)
        x = self.sigmoid(x)
        # Reshape back to image
        x = x.view(-1, 1, 28, 28)
        return x
# NOTE(review): `device`, `lr`, and `acpn_params` are defined elsewhere in the
# full script (not shown). `acpn_params` is presumably collected from
# model_acpn.parameters() — which is empty here (see traceback), hence the
# "optimizer got an empty parameter list" error.
model_acpn = ACPNAutoencoder().to(device)
print(f"ACPN model parameters count: {sum(1 for _ in model_acpn.parameters())}")
print("ACPN model parameters:")
for name, param in model_acpn.named_parameters():
    print(f" {name}: {param.shape}")
optimizer_acpn = optim.SGD(acpn_params, lr=lr)
Logs:
Using device: cuda
ACPN model parameters count: 0
ACPN model parameters:
Number of parameters for ACPN optimizer: 0
Number of parameters for Linear optimizer: 4
Traceback (most recent call last):
File "c:\AIProjects\ACPNC\acpn_autoencoder.py", line 88, in <module>
optimizer_acpn = optim.SGD(acpn_params, lr=lr) # Try SGD instead of Adam
File "C:\Users\OliverWH\anaconda3\envs\aiwork\lib\site-packages\torch\optim\sgd.py", line 63, in __init__
super().__init__(params, defaults)
File "C:\Users\OliverWH\anaconda3\envs\aiwork\lib\site-packages\torch\optim\optimizer.py", line 372, in __init__
raise ValueError("optimizer got an empty parameter list")
ValueError: optimizer got an empty parameter list
Can anyone help me understand the correct approach, or point me to a working implementation on GitHub that uses C++ extensions the way I am?