Are torch.sort() and torch.take() differentiable?

Hi there! I am writing a custom loss function that uses torch.sort() and torch.take(). However, with this loss function, a fully-connected feed-forward neural network (nothing special: linear, ReLU and batch norm layers) does not learn well on the dataset. I am therefore wondering whether there is a problem with my loss function, e.g., that the gradients are not back-propagated because of these two functions. I checked the official documentation and read some threads on the Internet, but unfortunately I couldn’t find an answer.

By the way, my motivation for defining this loss function is to penalize violations of monotonicity between two series of data. For example, I have two tensors x and y that contain the time series of two variables, and I know that y should increase/decrease as x increases/decreases. Hence, I sort x in ascending order and record the indices (done by torch.sort()), then re-order y using those indices (done by torch.take()). Lastly, I just need to look at the differences between consecutive entries of the re-ordered y, from which monotonicity can be checked.
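To make the procedure above concrete, here is a minimal sketch of it. The function name sorted_monotonicity_penalty and the ReLU-based penalty on negative steps are assumptions for illustration only; the post does not say how the differences are turned into a loss.

```python
import torch

def sorted_monotonicity_penalty(x, y):
    # Hypothetical helper, not taken from the original post.
    # Sort x in ascending order and keep the permutation that sorts it.
    _, order = torch.sort(x)
    # Re-order y with that permutation (x and y are 1-D here).
    y_by_x = torch.take(y, order)
    # Differences between consecutive entries of the re-ordered y.
    steps = y_by_x[1:] - y_by_x[:-1]
    # Assumed penalty: only negative steps (y falls while x rises) count.
    return torch.relu(-steps).mean()

x = torch.tensor([1.0, 4.0, 7.0, 3.0, 9.0, 0.0])
y_good = 2.0 * x + 1.0   # increases with x
y_bad = -x               # decreases as x increases

penalty_good = sorted_monotonicity_penalty(x, y_good)
penalty_bad = sorted_monotonicity_penalty(x, y_bad)
print(penalty_good.item(), penalty_bad.item())
```

With these inputs the penalty is zero for y_good and positive for y_bad.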

I really appreciate any feedback and thanks in advance for your time!

The outputs of both torch.sort and torch.take do have a grad_fn, except for the indices returned by torch.sort, which are integer tensors and carry no gradient:

>>> x = torch.randn(10,2, requires_grad=True)
>>> torch.take(x, torch.LongTensor([1]))
tensor([-0.2461], grad_fn=<TakeBackward0>)
>>> torch.sort(x)[0]
tensor([[-1.2673, -0.2461],
        [-0.6003,  0.2518],
        [-1.1847,  0.2308],
        [ 0.1341,  2.5518],
        [-3.3000,  0.4909],
        [-1.2083,  1.7219],
        [-0.1799,  1.0432],
        [-1.3900,  0.2327],
        [-0.3572,  0.0238],
        [-0.7487,  0.4416]], grad_fn=<SortBackward0>)
>>> torch.sort(x)[1]
tensor([[0, 1],
        [1, 0],
        [0, 1],
        [1, 0],
        [1, 0],
        [1, 0],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1]])

Also, one thing to check within your loss function is whether any intermediate gradients are 0. Gradients are calculated via the chain rule, so if one intermediate gradient is 0, every gradient upstream of it along that path will be 0 as well.

Also, check to make sure you’re not calling .detach() or torch.no_grad() within any part of your code, and make sure the tensors you want gradients for have requires_grad=True (model parameters do by default); otherwise no gradient will be computed for them.
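One way to run these checks is to inspect each parameter's .grad after a backward pass. This is only a rough sketch: the small model, the random inputs and the placeholder loss are all made up for illustration and would be replaced by the real network and the custom loss.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in model and data, only for illustration.
model = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 1))
inputs = torch.randn(16, 3)

out = model(inputs)
loss = out.pow(2).mean()  # placeholder loss; substitute the custom one here
loss.backward()

# Record the largest absolute gradient entry for every parameter.
grad_report = {}
for name, param in model.named_parameters():
    if param.grad is None:
        # No gradient reached this parameter at all.
        grad_report[name] = None
    else:
        grad_report[name] = param.grad.abs().max().item()
    print(name, grad_report[name])
```

A value of None means the parameter is disconnected from the loss, while a value of exactly 0.0 points to a zero gradient somewhere along the chain.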

Hi wgu!

If I understand your use case correctly, when you have
x[i + 1] > x[i], you would also like to have y[i + 1] > y[i],
and you would like to impose a penalty if y[i + 1] < y[i]. (And,
analogously, switching < for >.)

But you don’t care how much this joint monotonicity is violated. If
y[i + 1] is smaller than y[i] by just a little bit, you want to penalize
it by essentially the same amount as when y[i + 1] is smaller by a
lot.

Also, I assume that you want the gradient of your “non-monotonicity
loss” with respect to the values of x and y. For example, perhaps
x and y are generated by a neural network, and you want to train
that network to only generate such data series that are monotone
with respect to one another. Thus, you would want to backpropagate
the gradient of your loss through x and y to the network parameters
that generated them.

You don’t need (nor I think want) to use sorting for this.

torch.tanh() is a differentiable (“soft”) approximation to the
sign() function. So you could use (in vector form):

loss = (1.0 - (torch.tanh(alpha * (x[1:] - x[:-1])) * torch.tanh(alpha * (y[1:] - y[:-1])))).mean()

where alpha can be used to sharpen (or smooth out) the abruptness
with which tanh() approximates sign().
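As a quick sanity check, the expression can be wrapped in a function and tried on small tensors, with both consecutive differences taken as v[1:] - v[:-1]. The function name soft_monotonicity_loss, the default alpha and the sample values are assumptions chosen for illustration.

```python
import torch

def soft_monotonicity_loss(x, y, alpha=1.0):
    # Hypothetical wrapper name; alpha=1.0 is an arbitrary default.
    dx = x[1:] - x[:-1]  # consecutive differences of x
    dy = y[1:] - y[:-1]  # consecutive differences of y
    # tanh(alpha * d) is a soft sign(d); the product is close to +1 when
    # dx and dy agree in sign and close to -1 when they disagree.
    return (1.0 - torch.tanh(alpha * dx) * torch.tanh(alpha * dy)).mean()

x = torch.tensor([0.0, 1.0, 3.0, 2.0, 5.0], requires_grad=True)
y_same = torch.tensor([0.0, 2.0, 6.0, 4.0, 10.0], requires_grad=True)
y_opposite = torch.tensor([0.0, -2.0, -6.0, -4.0, -10.0], requires_grad=True)

loss_same = soft_monotonicity_loss(x, y_same, alpha=5.0)
loss_opposite = soft_monotonicity_loss(x, y_opposite, alpha=5.0)

# Gradients flow to both x and y, with no sorting involved.
loss_opposite.backward()
print(loss_same.item(), loss_opposite.item())
print(x.grad is not None, y_opposite.grad is not None)
```

The loss is near 0 when y moves in the same direction as x at every step and near 2 when it always moves the opposite way.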


K. Frank


Thanks AlphaBetaGamma96! That was a quick check on the grad_fn!

One more question: if I take the indices returned by torch.sort() and feed them to torch.take(), will the gradients be backpropagated? (I think yes, and I have posted my test code below.)

Test code:

import torch
from torch import nn
x = torch.tensor([1, 4, 7, 3, 9, 0], dtype=torch.float, requires_grad=True)
x_sorted = torch.sort(x)[0]
x_id = torch.sort(x)[1]
print(x_sorted, x_id)
y = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.float, requires_grad=True)
y_sorted = torch.take(y, x_id) # Here is my question: Will the gradients be backpropagated?
print(y, y_sorted)

loss_mse = nn.MSELoss()
loss_xy = loss_mse(x, y)
loss_ys = y_sorted.mean()
loss_xs =  x_sorted.mean()

print(f"Gradients: {loss_ys.item()}, {loss_xs.item()}, {loss_xy.item()}")

loss_xid = x_id.float().mean()
loss_xid.backward()  # raises RuntimeError: the sort indices have no grad_fn

which returns

tensor([0., 1., 3., 4., 7., 9.], grad_fn=<SortBackward>) tensor([5, 0, 3, 1, 2, 4])
tensor([0., 1., 2., 3., 4., 5.], requires_grad=True) tensor([5., 0., 3., 1., 2., 4.], grad_fn=<TakeBackward>)
Gradients: 2.5, 4.0, 14.166666984558105

RuntimeError                              Traceback (most recent call last)
<ipython-input-30-1004213fe0c3> in <module>
     21 loss_xid = x_id.float().mean()
---> 22 loss_xid.backward()

~/miniconda3/envs/base-tornado5/lib/python3.8/site-packages/torch/ in backward(self, gradient, retain_graph, create_graph, inputs)
    243                 create_graph=create_graph,
    244                 inputs=inputs)
--> 245         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    247     def register_hook(self, hook):

~/miniconda3/envs/base-tornado5/lib/python3.8/site-packages/torch/autograd/ in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    143         retain_graph = create_graph
--> 145     Variable._execution_engine.run_backward(
    146         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    147         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

The output shows that y_sorted has a grad_fn (TakeBackward), so loss_ys (which has no real meaning and only averages y_sorted for testing purposes) can be backpropagated to y. For loss_xid, an error is raised because torch.sort(x)[1] has no grad_fn, just as AlphaBetaGamma96 pointed out. So my answer to my own question is yes: torch.take(y, x_id) supports gradient backpropagation to y, with x_id = torch.sort(x)[1] and y being a tensor that has requires_grad=True.
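A more direct check is to call backward() and look at the .grad fields. The sketch below reuses the x and y values from the test code; the weights tensor is an assumption added only so that each position receives a distinct gradient. It suggests that the gradient reaches y through torch.take(), but that nothing flows back to x through the indices.

```python
import torch

x = torch.tensor([1.0, 4.0, 7.0, 3.0, 9.0, 0.0], requires_grad=True)
y = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)

x_sorted, x_id = torch.sort(x)   # x_id is an integer tensor without grad_fn
y_sorted = torch.take(y, x_id)   # differentiable with respect to y only

# Hypothetical weights so that every position gets a different gradient.
weights = torch.arange(1.0, 7.0)
loss = (weights * y_sorted).sum()
loss.backward()

print(y.grad)  # the weights, scattered back to the original positions of y
print(x.grad)  # None: x only entered the loss through the indices
```

So if a loss depends on x only through the sort indices, it cannot push any gradient back into whatever produced x.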

Let me know if I am making mistakes here. Thanks!

Thanks a million, KFrank! That’s exactly what I want for calculating the loss. I’m using it to train the network, and it seems to be working better than my previous implementation of the loss function. Very beautiful solution!