Piecewise constraint in PyTorch

Let’s say I am predicting a quaternion. How do I implement this piecewise constraint, which is conditioned on the scalar (real) part “w” (for quaternion = [x, y, z, w])?

In Python:

def positive_real_quat(quat):
    """Flip a quaternion to the positive-real hemisphere so that there is a single quaternion describing each rotation.

    :param quat: [x, y, z, w] as an array that supports negation (e.g. a NumPy array)
    """
    x, y, z, w = quat
    if w >= 0:
        return quat
    else:
        return -quat

But I need the above code to be differentiable, since it is the final function in my neural network. How do I do that?

Since you are only multiplying by -1 for a negative w, your function should not break the computation graph.
I’ve changed the input to a tensor and it seems to work fine:

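import torch

def positive_real_quat(quat):
    # quat has shape [1, 4]; split it into four [1, 1] tensors
    x, y, z, w = quat.split(1, 1)
    # `if w >= 0` needs w to hold a single value, so this handles one quaternion at a time
    if w >= 0:
        return quat
    else:
        return -quat

# Positive
x1 = torch.tensor([[1., 1., 1., 1.]], requires_grad=True)
out1 = positive_real_quat(x1)
# Check grad
out1.mean().backward()
print(x1.grad)
# tensor([[0.2500, 0.2500, 0.2500, 0.2500]])

# Negative
x2 = torch.tensor([[1., 1., 1., -1.]], requires_grad=True)
out2 = positive_real_quat(x2)
# Check grad
out2.mean().backward()
print(x2.grad)
# tensor([[-0.2500, -0.2500, -0.2500, -0.2500]])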
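The `if w >= 0` check only works when `quat` holds a single quaternion, because Python's `if` needs a single boolean. For a batch, one option is to make the choice element-wise with `torch.where`. The following is a minimal sketch, assuming a batch of shape [N, 4] in [x, y, z, w] order; `positive_real_quat_batched` is a hypothetical name, not an existing PyTorch function.

```python
import torch

def positive_real_quat_batched(quat):
    """Hypothetical helper: flip each quaternion in a [N, 4] batch
    ([x, y, z, w] order) so that its real part w is non-negative."""
    w = quat[..., 3:4]  # keep the last dim so it broadcasts against quat
    # torch.where selects per row; each row's gradient flows through
    # whichever branch (quat or -quat) that row takes
    return torch.where(w >= 0, quat, -quat)

q = torch.tensor([[1., 1., 1., 1.],
                  [1., 1., 1., -1.]], requires_grad=True)
out = positive_real_quat_batched(q)
out.mean().backward()
print(out)     # second row is negated
print(q.grad)  # +0.125 for the first row, -0.125 for the second
```

Multiplying by `torch.sign(w)` would look simpler, but `torch.sign` returns 0 for w = 0 and would zero out such quaternions, which is why this sketch uses `torch.where` instead.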

So it seems you can always “differentiate” through if/else conditionals, because the gradient only flows through the branch that is selected, much like torch.max is differentiable because the gradient only flows to the maximum value…
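A small check of that analogy, as a sketch: `torch.max` over a whole tensor sends the gradient only to the maximum element (assuming no ties), and picking the same element with Python-level logic gives the same gradient.

```python
import torch

# torch.max: the gradient goes only to the largest element
x = torch.tensor([1.0, 3.0, 2.0], requires_grad=True)
x.max().backward()
print(x.grad)  # tensor([0., 1., 0.])

# Python-level choice: compute the index outside autograd, then select
y = torch.tensor([1.0, 3.0, 2.0], requires_grad=True)
i = int(torch.argmax(y))
picked = y[i]
picked.backward()
print(y.grad)  # tensor([0., 1., 0.])
```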

Are there any cases where program logic is not differentiable?
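One case worth knowing about: the branch decision itself is not differentiable. Comparison operators return boolean tensors that are not part of the autograd graph, so a parameter used only inside a condition receives no gradient. A minimal sketch, with a hypothetical learnable threshold `t`:

```python
import torch

x = torch.tensor([1.0, 7.0])
a = torch.tensor(1.0, requires_grad=True)  # slope, used inside both branches
t = torch.tensor(5.0, requires_grad=True)  # hypothetical threshold, used only in the condition

cond = x < t  # boolean mask; comparisons are not differentiable
out = torch.where(cond, a * x, 2 * a * x).sum()
out.backward()
print(cond.requires_grad)  # False
print(a.grad)  # tensor(15.): 1 from x = 1 plus 2 * 7 from x = 7
print(t.grad)  # None: t only appears in the comparison, so it gets no gradient
```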

Note that torch.max and your approach are a bit different.
torch.max works on a single tensor and passes the gradient to the maximum value (as you’ve explained), whereas your approach creates a different computation graph depending on which branch is taken.
Since PyTorch builds the computation graph (which is used to calculate the gradients in the backward pass) dynamically during the forward pass, you can use plain Python conditions, loops, etc. in your code.
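To make the “different computation graphs” point concrete, one can inspect `grad_fn` on the outputs of the function from this thread. This is a sketch; the exact node name printed for the negation (e.g. `NegBackward0`) depends on the PyTorch version.

```python
import torch

def positive_real_quat(quat):
    x, y, z, w = quat.split(1, 1)
    if w >= 0:
        return quat
    else:
        return -quat

x1 = torch.tensor([[1., 1., 1., 1.]], requires_grad=True)
x2 = torch.tensor([[1., 1., 1., -1.]], requires_grad=True)
out1 = positive_real_quat(x1)
out2 = positive_real_quat(x2)

# Positive branch: the input leaf is returned unchanged, so no new node is recorded
print(out1 is x1, out1.grad_fn)  # True None
# Negative branch: autograd records a negation node for this call only
print(out2.grad_fn)
```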

I guess what I’m looking for is a rigorous argument that gradient descent will minimize a loss whose computation graph contains Python conditions and logic.
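One way to make this precise, sketched under the assumption that every branch is itself differentiable in the parameters $\theta$: write the loss as a sum over samples, where $k_i(\theta)$ is the branch taken for sample $i$. If the branch choices do not change on some neighborhood $U$ of $\theta_0$, the loss coincides there with an ordinary differentiable function, and its gradient is exactly what autograd computes by backpropagating through the recorded branches:

$$
L(\theta) = \sum_{i} \ell\!\left(f_{k_i(\theta)}(x_i;\theta)\right), \qquad k_i(\theta) \in \{1, \dots, K\}
$$

$$
k_i(\theta) = k_i(\theta_0)\ \text{for all}\ \theta \in U \;\Longrightarrow\; \nabla L(\theta_0) = \sum_{i} \nabla_\theta\, \ell\!\left(f_{k_i(\theta_0)}(x_i;\theta_0)\right)
$$

If the conditions depend only on the data, $k_i$ is constant for all $\theta$ and $L$ is an ordinary sum of differentiable terms, so gradient descent behaves as it would on that function. Where a branch choice flips as $\theta$ changes, $L$ can jump and has no gradient; autograd then simply returns the gradient of the branch that was taken.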

Hi, I am also very interested in this question. We are trying to implement a “custom” (specific to our problem) piecewise function using simple if/else logic. As a simplified example, we create torch parameters (a, b, c and d) and say that e.g. “if input < 5: value = a*input + b, else: value = c*input + d”.
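For this example, the per-sample partial derivatives show which parameters each input can update ($\mathbb{1}[\cdot]$ is 1 when the condition holds and 0 otherwise):

$$
f(x) = \begin{cases} a x + b, & x < 5 \\ c x + d, & x \ge 5 \end{cases}
\qquad
\frac{\partial f}{\partial a} = x\,\mathbb{1}[x < 5],\quad
\frac{\partial f}{\partial b} = \mathbb{1}[x < 5],\quad
\frac{\partial f}{\partial c} = x\,\mathbb{1}[x \ge 5],\quad
\frac{\partial f}{\partial d} = \mathbb{1}[x \ge 5]
$$

So for a loss summed or averaged over a batch, $a$ and $b$ receive gradient only from inputs below 5, and $c$ and $d$ only from inputs at or above 5.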

However, backpropagation and optimization seem to produce gradients for only some of the parameters, even though we know the inputs fall into both “ranges” (we have inputs both above and below 5, following the example above).

Will PyTorch accept such a function, and will backpropagation plus optimization (SGD) actually work for it?
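Since the original code isn't shown, this is an assumption: one possible cause of missing gradients is that the branch was chosen once for the whole batch rather than per element. (A plain Python `if input < 5` on a multi-element tensor raises an error, so the choice must be made somewhere else.) A minimal sketch that selects the branch element-wise with `torch.where` and fits the four parameters with SGD on hypothetical synthetic data:

```python
import torch

# Hypothetical data: a piecewise-linear target with a kink at x = 5
x = torch.linspace(0.0, 10.0, steps=50)
y = torch.where(x < 5, 2.0 * x + 1.0, -1.0 * x + 16.0)

# Learnable parameters a, b, c, d as in the question, initialised to zero
a, b, c, d = (torch.zeros((), requires_grad=True) for _ in range(4))

def piecewise(inp):
    # Element-wise branch selection: each sample only feeds gradients
    # into the parameters of the piece it falls into
    return torch.where(inp < 5, a * inp + b, c * inp + d)

def mse():
    return ((piecewise(x) - y) ** 2).mean()

# One backward pass: all four parameters receive a gradient because the
# batch contains inputs on both sides of the threshold
initial_loss = mse()
initial_loss.backward()
first_grads = [p.grad.clone() for p in (a, b, c, d)]

opt = torch.optim.SGD([a, b, c, d], lr=0.01)
for _ in range(2000):
    opt.zero_grad()
    loss = mse()
    loss.backward()
    opt.step()

final_loss = mse().item()
print(f"initial loss {initial_loss.item():.3f}, final loss {final_loss:.3f}")
print("a, b, c, d =", [round(p.item(), 2) for p in (a, b, c, d)])
```

With plain SGD at this learning rate, `c` and `d` converge slowly (their inputs are large and strongly correlated), so more steps or normalized inputs may be needed for a close fit. The point of the sketch is that every parameter receives a gradient and gets updated.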