How to implement a piecewise function that supports autograd?

I want to implement a piecewise function on a tensor, with the following logic for each element of the tensor.

import torch

def f(x):
    if x==0.5:
        return 1
    else:
        return torch.atanh(x)/(x-0.5)

Clearly, torch.atanh(x)/(x-0.5) is not defined when x==0.5.

The above function is written for a single element, so I can easily branch on the condition. But I want to know how to implement this logic on a multi-element tensor X, so that a tensor of the same shape as the input is returned, where each element has the value f(x) for the corresponding element x of X.

I could use the torch.where() function, but since both branches of torch.where() are computed, I can't avoid divide-by-zero errors when the element is 0.5.
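
For concreteness, here is a minimal sketch of the failure mode I'm worried about (the values are just for illustration): torch.where picks the right result in the forward pass, but the untaken branch is still evaluated, and its gradient contaminates x.grad at the element equal to 0.5.

import torch

x = torch.tensor([0.3, 0.5, 0.9], requires_grad=True)

# torch.where selects the correct values in the forward pass...
y = torch.where(x == 0.5, torch.ones_like(x), torch.atanh(x) / (x - 0.5))

# ...but the division by zero in the untaken branch still happens,
# and its NaN gradient leaks through during backward.
y.sum().backward()
print(x.grad)  # the entry for x == 0.5 comes out as nan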

torch.where should be okay because if you divide by 0, shouldn’t you be getting inf/nan?

This might be a solution to your question:

import torch

def fnew(x):
    return (~(x == 0.5)) * torch.nan_to_num(torch.atanh(x) / (x - 0.5)) + (x == 0.5) * 1

This is a branchless version of your if-else statement above. If x != 0.5 (as determined by (~(x==0.5))), the second term is multiplied by 0 and the expression returns atanh(x)/(x-0.5). If x == 0.5, the opposite happens and it returns 1.

I've only placed a torch.nan_to_num in there to convert any NaNs/infs to finite values (in particular the inf produced at x == 0.5, which would otherwise turn into 0 * inf = NaN). If your inputs are only within the range (-1, 1) and never hit 0.5 exactly, you can remove the torch.nan_to_num, as it'll serve no purpose then.
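
If you also care about the gradient at x == 0.5, another common workaround (just a sketch, not the only way to do it) is to mask the input before the division, so that the untaken branch never produces inf/NaN in the first place:

import torch

def f_where(x):
    mask = x == 0.5
    # Swap in a harmless dummy value *before* the division, so the
    # untaken branch never produces inf/NaN (not even in backward).
    safe_x = torch.where(mask, torch.zeros_like(x), x)
    return torch.where(mask, torch.ones_like(x), torch.atanh(safe_x) / (safe_x - 0.5))

x = torch.tensor([0.3, 0.5, 0.9], requires_grad=True)
f_where(x).sum().backward()
print(x.grad)  # finite everywhere; the gradient at x == 0.5 comes out as 0

With this version the gradient at the breakpoint comes from the constant branch (i.e. it is 0 there) rather than NaN.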


Is autograd able to handle this piecewise function? Is it differentiating on each sub-interval?

Looks like it's able to handle it:

import torch
def compute_krw(Sw):
    return Sw**3

Sw = torch.rand(4096, 50).requires_grad_(True)
krw = compute_krw(Sw)
krw.backward(torch.ones(krw.shape))

print(Sw)
tensor([[0.7071, 0.3658, 0.8739, …, 0.1909, 0.4205, 0.8919],
[0.8442, 0.5447, 0.8630, …, 0.7269, 0.8380, 0.0856],
[0.1522, 0.3359, 0.1544, …, 0.2483, 0.9337, 0.4175],
…,
[0.6864, 0.4651, 0.4591, …, 0.9432, 0.3384, 0.3128],
[0.8458, 0.0655, 0.3948, …, 0.5267, 0.7302, 0.8763],
[0.8550, 0.8541, 0.4046, …, 0.9442, 0.1375, 0.4865]],
requires_grad=True)

print(Sw.grad)
tensor([[ 4.1874e+00, 6.9268e-01, 0.0000e+00, …, 2.8091e-02,
1.0571e+00, 0.0000e+00],
[ 0.0000e+00, 2.1619e+00, 0.0000e+00, …, 4.4784e+00,
0.0000e+00, -3.8395e-03],
[-1.5801e-02, 5.2467e-01, -1.4277e-02, …, 1.6212e-01,
0.0000e+00, 1.0353e+00],
…,
[ 3.8929e+00, 1.4097e+00, 1.3591e+00, …, 0.0000e+00,
5.3783e-01, 4.1034e-01],
[ 0.0000e+00, 2.1532e-02, 8.7634e-01, …, 1.9774e+00,
4.5290e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 9.4362e-01, …, 0.0000e+00,
-2.2675e-02, 1.5962e+00]])

Autograd can handle control flow within models, since the autograd graph is rebuilt during every forward pass and freed after each .backward() call on the loss. The graph can therefore be different in each iteration as a result of control flow in the model's forward method, but it is created as normal, containing the backward functions for whichever path of the control flow gets executed.

For the sake of performance, though, torch.where operations will probably be faster than writing control flow in Python.
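
For illustration, a rough sketch of that comparison (actual timings depend on tensor size and device; the 0.5 breakpoint is taken from the earlier example). The loop builds one set of graph nodes per element, while the torch.where version runs a handful of vectorized kernels over the whole tensor:

import torch

def piecewise_loop(x):
    # Python-level control flow: one set of graph nodes per element.
    out = [torch.ones(()) if xi == 0.5 else torch.atanh(xi) / (xi - 0.5)
           for xi in x.flatten()]
    return torch.stack(out).reshape(x.shape)

def piecewise_where(x):
    # Vectorized: a few kernels for the whole tensor.
    mask = x == 0.5
    safe_x = torch.where(mask, torch.zeros_like(x), x)
    return torch.where(mask, torch.ones_like(x), torch.atanh(safe_x) / (safe_x - 0.5))

x = (torch.rand(1000) * 1.8 - 0.9).requires_grad_(True)  # values in (-0.9, 0.9)
assert torch.allclose(piecewise_loop(x), piecewise_where(x))

Both versions are differentiable; the vectorized one just builds a far smaller graph.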

If the gradient graph is created on every iteration (i.e. on every backward pass of a batch), what happens if individual sample outputs end up in the domain of different sub-functions of the piecewise function? Surely, these samples should have different gradient graphs.
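
To make that concrete, a toy sketch of the situation I mean (the breakpoint at 0.5 and the values are just for illustration): elements of one batch land in different sub-domains, take different branches, and then a single backward call follows.

import torch

# Elements of one batch deliberately landing in different sub-domains.
x = torch.tensor([0.2, 0.5, 0.8], requires_grad=True)

out = torch.stack([
    torch.ones(()) if xi == 0.5 else torch.atanh(xi) / (xi - 0.5)
    for xi in x
])

out.sum().backward()  # a single backward call over the whole batch
print(x.grad)         # does each element get its own piece of the graph?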