# How to implement a piecewise function that supports autograd?

I want to implement a piecewise function on a tensor that has the following logic for each element in the tensor.

```python
import torch

def f(x):
    if x == 0.5:
        return 1
    else:
        return torch.atanh(x) / (x - 0.5)
```

Clearly, torch.atanh(x)/(x-0.5) is not defined when x==0.5.

The above function is for one single element, and so I can easily make two branches. But I want to know how to implement this logic on a multi-element tensor X, such that a tensor of the same shape as the input is returned and each element has the value f(x), where x is the corresponding element in X.

I could use the `torch.where()` function, but since both branches of `torch.where()` are computed, I may not be able to avoid divide-by-zero errors when the element is 0.5.

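`torch.where` should be okay: dividing a float tensor by zero in PyTorch doesn't raise an error, it gives `inf` (or `nan` for `0/0`), and `torch.where` simply discards those values for the elements where it selects the other branch.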
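A minimal `torch.where` sketch of `f`, assuming `x` is a floating-point tensor (the name `f_where` is just for illustration). Both branches are evaluated for every element, but the `inf` produced at `x == 0.5` never reaches the output:

```python
import torch

def f_where(x: torch.Tensor) -> torch.Tensor:
    # Both branches are computed for every element; at x == 0.5 the
    # division yields inf, but torch.where picks the constant 1 there.
    is_half = x == 0.5
    return torch.where(is_half, torch.ones_like(x), torch.atanh(x) / (x - 0.5))

x = torch.tensor([0.0, 0.25, 0.5, 0.75])
print(f_where(x))  # the element at x == 0.5 is exactly 1.0
```

The forward value is fine. Be careful if you backpropagate through it, though: the masked-out division still runs its backward pass, which at `x == 0.5` evaluates `0/0`, so the gradient for that element comes out as `nan`.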

This might be a solution to your question:

```python
import torch

def fnew(x):
    return (~(x == 0.5)) * torch.nan_to_num(torch.atanh(x) / (x - 0.5)) + (x == 0.5) * 1
```

This is a branchless version of your if-else statement above. Where `x != 0.5` (computed as `~(x==0.5)`), the left-hand term is multiplied by 1 and the right-hand term `(x==0.5)*1` is 0, so it returns `atanh(x)/(x-0.5)`. Where `x == 0.5`, the opposite happens: the left-hand term is multiplied by 0 and the right-hand term contributes 1, so it returns 1.
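To see the masking arithmetic concretely, here is a small check at `x == 0.5` (a sketch; `keep`, `ratio`, `raw` and `safe` are illustrative names). Under IEEE floating-point rules `inf * 0` is `nan`, so the masked-out term only becomes 0 once `torch.nan_to_num` has replaced the `inf` with a large finite number:

```python
import torch

x = torch.tensor([0.5])
keep = ~(x == 0.5)                   # False at x == 0.5
ratio = torch.atanh(x) / (x - 0.5)   # atanh(0.5) / 0 -> inf

raw = keep * ratio                     # 0 * inf -> nan
safe = keep * torch.nan_to_num(ratio)  # 0 * (largest finite float) -> 0
print(raw, safe)
```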

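I’ve placed a `torch.nan_to_num` in there to convert any `nan`s to 0; it also replaces `inf` with the largest finite value of the dtype. Keep it even if your inputs are only within the range `(-1, 1)`: at `x == 0.5` the division gives `inf`, and multiplying `inf` by the 0 mask gives `nan`, which would end up in the result. For inputs outside `(-1, 1)`, `atanh` returns `nan`, which `nan_to_num` silently turns into 0.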
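Since the question asks about autograd: a plain `torch.where` and the multiply-by-mask `fnew` above both give the right forward values, but their gradient at `x == 0.5` is `nan`, because the backward pass of the masked-out division still computes `0/0`. A common workaround, often called the "double `where`" trick, is to replace the problematic inputs *before* the division. This is a sketch assuming `x` is a float tensor with values in `(-1, 1)`; the name `f_safe` and the placeholder value `0.0` are illustrative choices (any value in `(-1, 1)` other than 0.5 would work).

```python
import torch

def f_safe(x: torch.Tensor) -> torch.Tensor:
    is_half = x == 0.5
    # First where: swap x == 0.5 for a harmless placeholder (0.0) so the
    # division below never divides by zero, in forward or backward.
    x_safe = torch.where(is_half, torch.zeros_like(x), x)
    ratio = torch.atanh(x_safe) / (x_safe - 0.5)
    # Second where: pick the constant 1 at x == 0.5, the ratio elsewhere.
    return torch.where(is_half, torch.ones_like(x), ratio)

x = torch.tensor([0.0, 0.25, 0.5, 0.75], requires_grad=True)
y = f_safe(x)
y.sum().backward()
print(y)
print(x.grad)  # finite everywhere; 0 at x == 0.5, where f is constant
```

The gradient at `x == 0.5` is 0 because `f` is defined as the constant 1 there; the placeholder only affects a branch whose output and gradient are both discarded.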