How to write a custom module with a segmented (piecewise) function?

I want to use PyTorch as a simple function fitter, and the function's logic looks like this:

if (x < param_seg1) return param_a * x;
else if (x < param_seg2) return param_b * x * x;
else return param_c * x * x * x;

All of these param_xxx values need to be learned.
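A minimal sketch of that logic as a custom nn.Module, assuming a single scalar input x. The class name SegmentedFunction and the initial parameter values are hypothetical placeholders, not part of the question.

```python
import torch
import torch.nn as nn


class SegmentedFunction(nn.Module):
    """Piecewise function from the question; all five values are learnable."""

    def __init__(self):
        super().__init__()
        # Initial values are arbitrary placeholders (assumption).
        self.param_seg1 = nn.Parameter(torch.tensor(1.0))
        self.param_seg2 = nn.Parameter(torch.tensor(2.0))
        self.param_a = nn.Parameter(torch.tensor(1.0))
        self.param_b = nn.Parameter(torch.tensor(1.0))
        self.param_c = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        # Plain Python branching; only valid for a single-element tensor x.
        if x < self.param_seg1:
            return self.param_a * x
        elif x < self.param_seg2:
            return self.param_b * x * x
        else:
            return self.param_c * x * x * x


f = SegmentedFunction()
y = f(torch.tensor(1.5))  # middle branch with param_b = 1.0
```

Registering the values as nn.Parameter means f.parameters() can be handed straight to an optimizer such as torch.optim.SGD.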

Since I want to use PyTorch's autograd, all arithmetic has to be done with PyTorch's built-in operations. So how can I handle the segment points? Is there something like the step function in GPU shaders that can replace the if branches?
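For a batch of inputs, a Python if on a multi-element tensor raises an error, so a vectorized sketch could use torch.where, which plays roughly the role of a shader's step()-based select. The helper names segmented_where and segmented_soft and the sharpness value are hypothetical. The comparison x < seg1 is not differentiable, so in the hard version seg1 and seg2 receive no gradient; the soft version swaps in a sigmoid as a smooth step (a different technique from the question's hard switch) so the segment points can be learned too.

```python
import torch


def segmented_where(x, seg1, seg2, a, b, c):
    # All three branches are computed elementwise, then torch.where picks one.
    low = a * x
    mid = b * x * x
    high = c * x * x * x
    return torch.where(x < seg1, low, torch.where(x < seg2, mid, high))


def segmented_soft(x, seg1, seg2, a, b, c, sharpness=50.0):
    # sigmoid(k * (x - s)) is ~0 below s and ~1 above it: a smooth step.
    # Unlike the hard comparison, it is differentiable w.r.t. s.
    s1 = torch.sigmoid(sharpness * (x - seg1))
    s2 = torch.sigmoid(sharpness * (x - seg2))
    return (1 - s1) * a * x + s1 * (1 - s2) * b * x * x + s2 * c * x * x * x


x = torch.tensor([0.5, 1.5, 3.0])
one, two = torch.tensor(1.0), torch.tensor(2.0)
y_hard = segmented_where(x, one, two, one, one, one)
y_soft = segmented_soft(x, one, two, one, one, one)
```

With a large sharpness the soft version approaches the hard one, but its gradient with respect to the segment points also becomes tiny for inputs far from a boundary.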

You should be able to execute the code directly, without any changes to the if conditions etc.
PyTorch uses “eager” mode by default, i.e. it executes each line of code as soon as it reaches it, so it understands ordinary Python conditions, loops, etc.
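A small check of the eager-mode behaviour, assuming scalar parameters created with torch.tensor(..., requires_grad=True); the names mirror the question but are illustrative. Autograd records only the operations of the branch that actually ran.

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = torch.tensor(4.0, requires_grad=True)
seg1 = torch.tensor(1.0, requires_grad=True)
seg2 = torch.tensor(2.0, requires_grad=True)


def f(x):
    # The Python if is evaluated on real values at run time (eager mode).
    if x < seg1:
        return a * x
    elif x < seg2:
        return b * x * x
    else:
        return c * x * x * x


y = f(torch.tensor(1.5))  # middle branch: 3 * 1.5 * 1.5
y.backward()
# Only b took part in the computation, so only b gets a gradient (x * x).
print(a.grad, b.grad, c.grad, seg1.grad)
```

The comparisons themselves contribute no gradient, so seg1 and seg2 are never updated this way; learning the segment points needs a smooth stand-in for the step.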