Thanks AlphaBetaGamma96! That was a quick check on the grad_fn!
One more question: if I use the indices returned by torch.sort() and feed them to torch.take(), will the gradients be backpropagated? (I think yes; my test code is below.)
Test code:
import torch
from torch import nn

x = torch.tensor([1, 4, 7, 3, 9, 0], dtype=torch.float, requires_grad=True)
x_sorted, x_id = torch.sort(x)  # sorted values (differentiable) and integer permutation indices
print(x_sorted, x_id)

y = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.float, requires_grad=True)
y_sorted = torch.take(y, x_id)  # Here is my question: will the gradients be backpropagated?
print(y, y_sorted)

loss_mse = nn.MSELoss()
loss_xy = loss_mse(x, y)
loss_ys = y_sorted.mean()  # dummy loss on y_sorted, only for testing
loss_xs = x_sorted.mean()  # dummy loss on x_sorted, only for testing
loss_ys.backward()
loss_xs.backward()
loss_xy.backward()
print(f"Gradients: {loss_ys.item()}, {loss_xs.item()}, {loss_xy.item()}")  # note: these print the loss values

loss_xid = x_id.float().mean()  # x_id is an integer index tensor without a grad_fn ...
loss_xid.backward()             # ... so this backward() is expected to raise a RuntimeError
which returns
tensor([0., 1., 3., 4., 7., 9.], grad_fn=<SortBackward>) tensor([5, 0, 3, 1, 2, 4])
tensor([0., 1., 2., 3., 4., 5.], requires_grad=True) tensor([5., 0., 3., 1., 2., 4.], grad_fn=<TakeBackward>)
Gradients: 2.5, 4.0, 14.166666984558105
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-30-1004213fe0c3> in <module>
20
21 loss_xid = x_id.float().mean()
---> 22 loss_xid.backward()
~/miniconda3/envs/base-tornado5/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
243 create_graph=create_graph,
244 inputs=inputs)
--> 245 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
246
247 def register_hook(self, hook):
~/miniconda3/envs/base-tornado5/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
143 retain_graph = create_graph
144
--> 145 Variable._execution_engine.run_backward(
146 tensors, grad_tensors_, retain_graph, create_graph, inputs,
147 allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
It can be seen that gradients are backpropagated for loss_ys (which has no real meaning; it just averages y_sorted for testing purposes). For loss_xid, an error is raised because x_id, the index tensor returned by torch.sort(x), has no grad_fn, just as AlphaBetaGamma96 pointed out. So the answer to my question is yes: torch.take(y, x_id) supports gradient backpropagation when x_id comes from torch.sort(x) and y is a tensor with requires_grad=True.
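As a sanity check (my own snippet, not part of the run above), one can also inspect the .grad attributes directly instead of the loss values; the variable names mirror the test code:
import torch

x = torch.tensor([1, 4, 7, 3, 9, 0], dtype=torch.float, requires_grad=True)
y = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.float, requires_grad=True)

x_sorted, x_id = torch.sort(x)
y_sorted = torch.take(y, x_id)
y_sorted.mean().backward()

print(x_id.requires_grad)  # False: the integer indices are not part of the autograd graph
print(y.grad)              # every entry is 1/6, since each element of y is taken exactly once
print(x.grad)              # None: no differentiable path from y_sorted back to x through the indices
So gradients reach y through torch.take, while x receives nothing from y_sorted because the index path is not differentiable.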
Let me know if I am making mistakes here. Thanks!