Runtime error in custom Autograd function


I’m having an issue with a custom autograd function that implements a sorting operation on an input x.

The operation reorders the rows of x using the indices that sort an input vector s.

The sorting itself works fine outside of the Function, but I am having problems implementing it inside the forward and backward methods.

import torch
from torch.autograd import Function

class SortBy(Function): 
  def forward(ctx, x, s):
    _, indices = torch.sort(s)
    results = x[indices, :]

    return results
  def backward(ctx, grad_output, s):
    results, = ctx.saved_tensors

    _, indices = torch.sort(s)
    _, index2 = torch.sort(indices)

    return (results * grad_output)[index2, :]

I’m attempting to test my code using the following:

x = torch.autograd.Variable(torch.randn(4,5),requires_grad=True)

s = torch.randn(4)

Y = SortBy.apply(x, s)

Yhat = torch.sum(Y)

Yhat.backward(s)


But I get the below error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-283-b5596230715d> in <module>()
      7 Yhat = torch.sum(Y)
----> 9 Yhat.backward(s)

2 frames
/usr/local/lib/python3.6/dist-packages/torch/autograd/ in _make_grads(outputs, grads)
     27                                    + str(grad.shape) + " and output["
     28                                    + str(outputs.index(out)) + "] has a shape of "
---> 29                                    + str(out.shape) + ".")
     30             new_grads.append(grad)
     31         elif grad is None:

RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([4]) and output[0] has a shape of torch.Size([]).

I don’t understand where the dimension mismatch comes from. I assumed that the grad_output from the torch.sum step would be a single scalar gradient, which I then broadcast against my results matrix. After that I apply my reverse indexing to ‘reroute’ the gradient to where I want it.


You don’t need Variables anymore; you can simply do: torch.randn(4, 5, requires_grad=True)
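As a quick illustration of that point, here is a small sketch of creating the input directly as a tensor that tracks gradients. The variable name x follows the question; the checks are only there to show that no Variable wrapper is involved.

```python
import torch

# Create the leaf tensor directly; no torch.autograd.Variable wrapper needed.
x = torch.randn(4, 5, requires_grad=True)

# x is a plain tensor, a leaf of the graph, with gradient tracking enabled.
print(type(x))
print(x.requires_grad, x.is_leaf)
```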

For your Function: remember that the backward pass works in reverse from the forward pass, so the inputs to backward match the outputs of forward.
In your case, backward should take only one input (besides ctx): the gradient of the loss wrt result. It should return two things: the gradients wrt x and s.
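
Putting that together, here is a minimal sketch of how the Function could look. It makes a few assumptions that are not in the original post: the sort indices are saved in forward instead of being recomputed from s, the gradient for s is returned as None (the sort indices are piecewise constant in s, so there is no useful gradient to pass back), and the gradient wrt x is taken to be grad_output with the row permutation undone, without multiplying by the forward result.

```python
import torch
from torch.autograd import Function


class SortBy(Function):
    @staticmethod
    def forward(ctx, x, s):
        # Indices that sort s in ascending order; output row i is x[indices[i]].
        indices = torch.argsort(s)
        # Keep the permutation so backward does not need s as an argument.
        ctx.save_for_backward(indices)
        return x[indices]

    @staticmethod
    def backward(ctx, grad_output):
        # One input here, because forward returned one output.
        (indices,) = ctx.saved_tensors
        # The argsort of a permutation is its inverse permutation.
        inverse = torch.argsort(indices)
        # Route each output-row gradient back to the input row it came from.
        grad_x = grad_output[inverse]
        # Two return values, one per forward input (x, s).
        # Assumption: s gets no gradient, so None is returned for it.
        return grad_x, None


x = torch.randn(4, 5, requires_grad=True)
s = torch.randn(4)

Y = SortBy.apply(x, s)
Yhat = torch.sum(Y)
Yhat.backward()  # scalar output, so no gradient argument is passed

print(x.grad)  # every entry is 1 for a plain sum
```

Since Yhat is a scalar, backward() is called without arguments here. The shape mismatch in the traceback comes from passing s (shape [4]) as the gradient argument for an output of shape [].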