I want to use PyTorch as a simple function fitter. The function I want to fit looks like this:
if (x < param_seg1) return param_a * x;
else if (x < param_seg2) return param_b * x * x;
else return param_c * x * x * x;
All of these param_xxx values need to be learned, including the segment points param_seg1 and param_seg2.
Since I want to use PyTorch's autograd, all arithmetic has to be done with PyTorch's built-in operations. So how can I handle the segment points? Is there anything like the step function in GPU shaders that can replace the if/else branches?
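
To make the question concrete, here is a minimal sketch of the kind of thing I have in mind. It assumes a sigmoid can stand in for a shader-style step so that the segment points stay differentiable. The names soft_step and PiecewiseFit, the sharpness value, and the initial parameter values are my own placeholders, not anything PyTorch provides, and the blending assumes param_seg1 < param_seg2.

```python
import torch


def soft_step(x, edge, sharpness=10.0):
    # Smooth stand-in for a shader-style step(edge, x):
    # close to 0 below `edge`, close to 1 above it.
    # `sharpness` is an assumed tuning knob, not a PyTorch setting.
    return torch.sigmoid(sharpness * (x - edge))


class PiecewiseFit(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Initial values are arbitrary placeholders.
        self.param_a = torch.nn.Parameter(torch.tensor(1.0))
        self.param_b = torch.nn.Parameter(torch.tensor(1.0))
        self.param_c = torch.nn.Parameter(torch.tensor(1.0))
        self.param_seg1 = torch.nn.Parameter(torch.tensor(1.0))
        self.param_seg2 = torch.nn.Parameter(torch.tensor(2.0))

    def forward(self, x):
        s1 = soft_step(x, self.param_seg1)  # ~1 once x passes param_seg1
        s2 = soft_step(x, self.param_seg2)  # ~1 once x passes param_seg2
        # Blend the three branches instead of branching with if/else.
        return ((1 - s1) * self.param_a * x
                + s1 * (1 - s2) * self.param_b * x * x
                + s2 * self.param_c * x * x * x)


model = PiecewiseFit()
x = torch.linspace(0.0, 3.0, 16)
loss = model(x).pow(2).mean()  # dummy loss, only to exercise autograd
loss.backward()
```

torch.where(x < param_seg1, ..., ...) would also select between the branches elementwise, but as far as I can tell the hard comparison passes no gradient to param_seg1 and param_seg2, which is why the sketch blends with a sigmoid instead.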