Using torch::normal with the C++ frontend

I am probably missing something, but I am having trouble using the torch::normal function in the C++ API. Here’s the minimal example:

#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor tensor = torch::rand({2, 3});
  auto x = torch::normal(tensor,1);
  std::cout << x << std::endl;
}

The compilation of this example fails with:

/Users/dfalbel/Documents/testtorch/example-app.cpp:6:12: error: too few arguments to function call, expected at least 3, have 2; did you mean 'at::normal'?
  auto x = torch::normal(tensor,1);
           ^~~~~
           at::normal
/Users/dfalbel/libtorch/include/ATen/Functions.h:13495:22: note: 'at::normal' declared here
static inline Tensor normal(const Tensor & mean, double std, Generator * generator) {
                     ^
1 error generated.

This can be fixed by using the at namespace instead of torch, i.e. by calling at::normal(tensor, 1). I could do that, but it seems to me that all signatures of at::normal (1, 2, 3, 4) should be exported in the torch namespace too. Does this make sense?

It seems that the reason is that when we autogenerate wrappers for some ATen functions here, we generate a single torch::normal definition and not all of its overloads.

And this seems to happen because of this line which excludes from the variable_factories.h all functions that don’t have TensorOptions as an argument, thus excluding the other signatures.
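To illustrate why a single generated wrapper can make the other overloads unreachable, here is a minimal sketch that does not use libtorch at all. The namespaces at_mock, torch_mock and torch_fixed, the Tensor struct and the call_fixed helper are all hypothetical stand-ins, and I am assuming that torch re-exports at through a using-directive. With qualified lookup, once a namespace declares the name itself, its using-directives are no longer searched, whereas a using-declaration adds every overload to the set.

```cpp
#include <string>

// Hypothetical stand-in for the ATen namespace: two non-factory overloads.
namespace at_mock {
struct Tensor {};
inline std::string normal(const Tensor& /*mean*/, double /*std*/) {
  return "tensor mean, scalar std";
}
inline std::string normal(double /*mean*/, const Tensor& /*std*/) {
  return "scalar mean, tensor std";
}
}  // namespace at_mock

// Mimics the situation in this thread (assumption): at_mock is visible via a
// using-directive, but a single wrapper named `normal` is declared here, so
// qualified lookup of torch_mock::normal stops at that wrapper.
namespace torch_mock {
using namespace at_mock;
inline std::string normal(double /*mean*/, double /*std*/, int /*size*/) {
  return "factory wrapper";
}
}  // namespace torch_mock

// One possible fix: a using-declaration pulls every at_mock overload into the
// namespace, next to the wrapper.
namespace torch_fixed {
using namespace at_mock;
using at_mock::normal;
inline std::string normal(double /*mean*/, double /*std*/, int /*size*/) {
  return "factory wrapper";
}
}  // namespace torch_fixed

// Hypothetical helper showing a call that only resolves in torch_fixed.
inline std::string call_fixed() {
  at_mock::Tensor t;
  // torch_mock::normal(t, 1.0);  // would not compile: only the wrapper is found
  return torch_fixed::normal(t, 1.0);
}
```

Calling call_fixed() picks the "tensor mean, scalar std" overload, while the commented-out torch_mock call fails with the same kind of "too few arguments" error as in the original post. Whether a plain using-declaration is the right fix for the real generated code is a separate question, since it says nothing about autograd.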

Adding the following to my code fixes the compiler error, but I am pretty sure it is wrong in autograd contexts:

namespace torch {

// Workaround: forward the non-factory overloads of normal to ATen.
// `inline` avoids multiple-definition errors if this lives in a header.

// Tensor mean, scalar std.
inline torch::Tensor normal(const torch::Tensor &mean, double std = 1, torch::Generator *generator = nullptr) {
  return at::normal(mean, std, generator);
}

// Scalar mean, tensor std.
inline torch::Tensor normal(double mean, const torch::Tensor &std, torch::Generator *generator = nullptr) {
  return at::normal(mean, std, generator);
}

// Tensor mean, tensor std.
inline torch::Tensor normal(const torch::Tensor &mean, const torch::Tensor &std, torch::Generator *generator = nullptr) {
  return at::normal(mean, std, generator);
}

} // namespace torch