How to implement a piecewise function that supports autograd?

I want to implement a piecewise function on a tensor that has the following logic for each element in the tensor.

import torch

def f(x):
    # special-case value chosen at x == 0.5, where the closed form divides by zero
    if x == 0.5:
        return 1
    else:
        return torch.atanh(x) / (x - 0.5)

Clearly, torch.atanh(x)/(x-0.5) is not defined when x==0.5.

The function above operates on a single element, so I can easily branch. But I want to know how to implement this logic on a multi-element tensor X, such that a tensor of the same shape as the input is returned, where each element has the value f(x) for the corresponding element x of X.
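To pin down the target semantics, here is a naive elementwise reference (my own sketch, not from the question): it applies the scalar f in a Python loop, which matches the desired behaviour but is exactly the slow, non-vectorized version the question wants to avoid.

```python
import torch

# Naive reference: apply the scalar piecewise rule to each element in turn.
# Correct, but slow; the goal is a vectorized equivalent.
def f_loop(X):
    vals = [torch.tensor(1.0) if x == 0.5 else torch.atanh(x) / (x - 0.5)
            for x in X.flatten()]
    return torch.stack(vals).reshape(X.shape)

out = f_loop(torch.tensor([0.0, 0.5]))
print(out)  # the entry at x == 0.5 is 1
```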

I could use the torch.where() function, but since both branches of torch.where() are computed, I may not avoid divide-by-zero errors when the element is 0.5.

torch.where should be okay because if you divide by 0, shouldn’t you be getting inf/nan?
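The forward values do indeed come out fine, but the catch shows up in the backward pass. A small sketch of my own to demonstrate: torch.where evaluates both branches, and the gradient of the unselected division branch is infinite at x == 0.5, which becomes nan once autograd multiplies it by the zero mask.

```python
import torch

# torch.where picks the constant 1 at x == 0.5, so the forward value is fine,
# but the gradient of the division branch is infinite there and the backward
# pass produces nan for that element.
x = torch.tensor([0.0, 0.5], requires_grad=True)
y = torch.where(x == 0.5, torch.ones_like(x), torch.atanh(x) / (x - 0.5))
y.sum().backward()
print(y)       # forward values look fine: the entry at 0.5 is 1
print(x.grad)  # but the gradient at x == 0.5 is nan
```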

This might be a solution to your question:

def fnew(x):
  return (~(x == 0.5)) * torch.nan_to_num(torch.atanh(x) / (x - 0.5)) + (x == 0.5) * 1

This is a branchless version of your if-else statement above. If x != 0.5 (as picked out by ~(x == 0.5)), the mask on the left-hand term is 1 and the expression returns atanh(x)/(x-0.5). If x == 0.5, the opposite happens: the left-hand term is multiplied by 0 and it returns 1.

I’ve placed a torch.nan_to_num in there to convert any nans (and infs) to finite values. Note that it does real work even when your inputs stay within (-1, 1): at exactly x == 0.5 the division yields inf, and multiplying inf by the zero mask would give nan, so the nan_to_num is what keeps that entry clean.
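If you also need clean gradients, a common alternative worth mentioning (my own sketch, not from this thread) is to substitute a harmless dummy value into the input *before* the unsafe computation, so the division is never actually evaluated at the singular point and no inf/nan enters the graph in either direction:

```python
import torch

# Gradient-safe variant: mask the input first, then divide. The dummy value 0
# is arbitrary; it only needs to keep atanh(x)/(x - 0.5) finite, since the
# final torch.where discards that branch's value at the masked positions.
def f_safe(x):
    mask = x == 0.5
    x_safe = torch.where(mask, torch.zeros_like(x), x)
    return torch.where(mask, torch.ones_like(x), torch.atanh(x_safe) / (x_safe - 0.5))

x = torch.tensor([0.0, 0.5, -0.25], requires_grad=True)
y = f_safe(x)
y.sum().backward()
print(y)       # the entry at x == 0.5 is 1
print(x.grad)  # every gradient entry is finite
```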