I want to implement a piecewise function on a tensor that applies the following logic to each element:
```python
import torch

def f(x):
    if x == 0.5:
        return 1
    else:
        return torch.atanh(x) / (x - 0.5)
```
Clearly, torch.atanh(x)/(x-0.5) is not defined when x==0.5.
The function above handles a single element, so I can easily branch. But how can I apply this logic to a multi-element tensor X, so that the result is a tensor of the same shape as X in which each element equals f(x) for the corresponding element x of X?
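One way to get this per-element behavior without ever evaluating the division at x == 0.5 is boolean-mask indexing: fill the output with the special-case value, then compute the general formula only on the elements that need it. This is a minimal sketch, assuming X is a floating-point tensor; the name `f_masked` is hypothetical.

```python
import torch

def f_masked(X: torch.Tensor) -> torch.Tensor:
    # Start with the special-case value (1) everywhere; same shape and dtype as X.
    out = torch.ones_like(X)
    # Elements where the general formula applies (x != 0.5).
    mask = X != 0.5
    xm = X[mask]
    # Only the selected elements are divided, so (xm - 0.5) is never zero here.
    out[mask] = torch.atanh(xm) / (xm - 0.5)
    return out

X = torch.tensor([[0.1, 0.5], [-0.3, 0.5]])
print(f_masked(X))
```

Because `X[mask]` flattens the selected elements and `out[mask] = ...` writes them back in the same order, the result keeps the original shape of X.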
I could use torch.where(), but since both of its branches are computed, I may not be able to avoid divide-by-zero errors when an element is 0.5.
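For floating-point tensors, PyTorch does not raise an exception on division by zero; it produces inf (or nan for 0/0). So a single torch.where() still gives the correct forward values, but in the backward pass the zero gradient flowing into the unselected branch gets multiplied by inf, which yields NaN gradients at x == 0.5. A common workaround, sketched below under the assumption that X is a floating-point tensor, is to replace the problematic inputs with a harmless placeholder before dividing and then select with torch.where() as usual. The name `f_where` and the placeholder value 0.0 are illustrative choices, not part of any API.

```python
import torch

def f_where(X: torch.Tensor) -> torch.Tensor:
    is_special = X == 0.5
    # Swap x == 0.5 for a placeholder (0.0, arbitrary) so the general branch
    # never divides by zero; atanh(0) / (0 - 0.5) is finite.
    safe_X = torch.where(is_special, torch.zeros_like(X), X)
    general = torch.atanh(safe_X) / (safe_X - 0.5)
    # Use 1 at the special points and the general formula everywhere else.
    return torch.where(is_special, torch.ones_like(X), general)

X = torch.tensor([0.1, 0.5, -0.3], requires_grad=True)
Y = f_where(X)
Y.sum().backward()
print(Y)
print(X.grad)
```

Both branches are still evaluated, but neither contains inf, so the gradients stay finite; the gradient at x == 0.5 comes out as 0 because the selected value there is a constant.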