How to swap tensor elements and keep gradient

Hi, I am new to PyTorch.

I want to swap elements in a 3D tensor of shape (batch_size, seq_len, elem_dim) that contains scattered non-zero elements, while keeping the gradient tracked. For example:

# each integer below stands for a vector of size elem_dim; 0 is the zero vector
inputs = [[1, 0, 2],
          [0, 3, 0],
          [0, 0, 4]]

outputs = func(inputs)
# non-zero vectors are packed to the front of each sequence, zero-padded
assert outputs == [[1, 2],
                   [3, 0],
                   [4, 0]]
outputs.sum().backward()  # should raise no errors

How can I implement such a function?

I come from TensorFlow, where I know this kind of operation is somewhat inconvenient and inefficient, but TensorFlow at least allows gradient tracking through scatter_fn and gather_fn.

If some swap function is provided in the newest version of PyTorch, I think I would just use that.

I would say you can do that by modifying tensor.data, like this:

import torch

a = torch.rand(10).requires_grad_(True)
dummy = torch.Tensor(list(range(10)))
optim = torch.optim.SGD([a], lr=1)

optim.zero_grad()

b = (a * a).mean()
print('FAKE A')
print(b.grad_fn)
a.data = dummy        # swap the underlying data in place; autograd does not record this
print(b.grad_fn)      # the graph of b is unchanged by the swap
b.backward()

print(a.grad)         # gradient based on the values captured when b was built

print('REAL A')
optim.zero_grad()
b = (a * a).mean()    # rebuild b after the swap
b.backward()
print(a.grad)         # gradient based on the swapped-in (dummy) values

Output:

FAKE A
<MeanBackward0 object at 0x7f5d5c729710>
<MeanBackward0 object at 0x7f5d5c729710>
tensor([0.1948, 0.1893, 0.0677, 0.0387, 0.0062, 0.1082, 0.1482, 0.1386, 0.0594,
        0.0293])
REAL A
tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 1.2000, 1.4000, 1.6000,
        1.8000])

Sorry for the late reply. I tried scatter and confirmed that it is the function I was looking for.
I was misled by scatter_nd in TensorFlow, which has more general scattering capabilities.
torch.scatter, by contrast, is simple and handy, and the gradients are tracked.
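For anyone landing here later, this is a minimal sketch of the kind of compaction I had in mind. It uses gather (rather than scatter) to pull the non-zero vectors to the front of each sequence; the helper name compact_nonzero and the zero-vector test are my own choices for illustration, not an official API:

import torch

def compact_nonzero(x):
    """Move non-zero vectors to the front of each sequence, keeping gradients.

    x: float tensor of shape (batch_size, seq_len, elem_dim)
    returns: tensor of shape (batch_size, max_nonzero_count, elem_dim)
    """
    batch_size, seq_len, elem_dim = x.shape
    # A position is "non-zero" if any component of its vector is non-zero.
    mask = (x != 0).any(dim=-1)                               # (batch, seq_len)
    # Sort key: non-zero positions first (in original order), then zero positions.
    pos = torch.arange(seq_len, device=x.device)
    key = (~mask).long() * seq_len + pos                      # (batch, seq_len)
    order = key.argsort(dim=1)
    # gather is differentiable, so gradients flow back to the original slots.
    gathered = x.gather(1, order.unsqueeze(-1).expand(-1, -1, elem_dim))
    # Keep only as many columns as the largest per-row count of non-zero vectors.
    width = int(mask.sum(dim=1).max())
    return gathered[:, :width, :]

# Matches the toy example above (elem_dim = 1 here):
x = torch.tensor([[1., 0., 2.],
                  [0., 3., 0.],
                  [0., 0., 4.]], requires_grad=True)
out = compact_nonzero(x.unsqueeze(-1)).squeeze(-1)
print(out)               # tensor([[1., 2.], [3., 0.], [4., 0.]], grad_fn=...)
out.sum().backward()     # raises no error, x.grad is populated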

Thank you!