What is the best way to pass in a tensor that could optionally be None to a function?

I defined a forward function in a module. It checks whether the noise parameter is None: if it is, the function creates its own noise; otherwise it uses the noise that was passed in. However, I could not get this to work with the code below, because the assignment mynoise = noise; fails: c10::optional<at::Tensor> cannot be converted to torch::Tensor. What is the best way to pass a tensor that may be None to a function? Thanks!

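torch::Tensor forward(torch::Tensor image, c10::optional<at::Tensor> noise) {
    torch::Tensor mynoise;
    if (noise == None) {  // None is not a C++ name (c10::nullopt is the C++ counterpart)
        auto batch = image.size(0);  // Python equivalent: batch, _, height, width = image.shape
        auto height = image.size(2);
        auto width = image.size(3);
        mynoise = image.new_empty({batch, 1, height, width}).normal_();
    } else {
        mynoise = noise;  // error: c10::optional<at::Tensor> does not convert to torch::Tensor
    }
    return image + weight * mynoise;  // weight is a member of the module
}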

Does anyone have any suggestions?

The typical approach is to take a plain tensor argument (const Tensor& or Tensor) and then test t.defined().
None is translated to an undefined Tensor.

Best regards

Thomas


@tom Many thanks! To clarify: in my forward function, I need to check noise.defined() instead of noise == None. But if I want to call forward with noise undefined, how do I do that without giving noise a default value in the function signature? Do I just declare torch::Tensor noise_dummy; and pass in noise_dummy?

That would work, but the most succinct way of passing “None” would probably be the empty initializer list {}.

Best regards

Thomas


I am using PyTorch 1.6.0, and my C++ code looks like this:

void fit(const at::Tensor& x, const at::Tensor& y, const at::Tensor& w) {
    if (w.defined()) {
        // weights were passed: weighted fit
    } else {
        // no weights: unweighted fit
    }
}

and the Python binding code (using pybind11) is:

m.def("fit", &fit, py::arg("x"), py::arg("y"), py::arg("w") = py::none());

It compiles fine. However, calling it from Python as fit(x, y) or fit(x, y, None) fails with:

TypeError: fit(): incompatible function arguments.

Thanks.

The point is that None in Python maps to an undefined torch::Tensor in C++. The m.def is already on the C++ side, so the default argument should be a default-constructed torch::Tensor().

I tried py::arg("w") = at::Tensor(), but it still fails with the same error.