I want to swap elements in a 3D tensor of shape (batch_size, seq_len, elem_dim) that contains scattered non-zero elements, while keeping gradients tracked. For example,
inputs = [[1 0 2]
          [0 3 0]
          [0 0 4]]  # each non-zero integer represents a vector
Coming from TensorFlow, I know this is somewhat inconvenient and inefficient to do, but TensorFlow allows gradient tracking with scatter_fn and gather_fn.
If the newest version of PyTorch provides such a swap function, I would like to use it.
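In case it helps others with the same question: one way to express a differentiable swap is to build a permutation of positions along seq_len and use torch.gather, which is differentiable with respect to its input. Below is a minimal sketch with a made-up (1, 3, 2) toy tensor and a hypothetical permutation that swaps positions 0 and 2; the shapes and values are only illustrative.

```python
import torch

# Toy input of shape (batch_size=1, seq_len=3, elem_dim=2); values are arbitrary.
x = torch.tensor([[[1., 1.], [2., 2.], [3., 3.]]], requires_grad=True)

# Permutation along seq_len that swaps positions 0 and 2, shape (batch_size, seq_len).
perm = torch.tensor([[2, 1, 0]])

# gather needs an index with the same number of dims as x, so broadcast
# the permutation over elem_dim: shape (batch_size, seq_len, elem_dim).
idx = perm.unsqueeze(-1).expand(-1, -1, x.size(-1))

# torch.gather is differentiable, so gradients flow back to x.
y = torch.gather(x, 1, idx)

y.sum().backward()
print(y[0, 0])   # the vector that was originally at position 2
print(x.grad)    # all ones: each element of x is used exactly once
```

Since the permutation only reorders elements, every input element contributes exactly once to the output, which is why the gradient of the sum is all ones.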
Sorry for the late reply. I tried scatter and confirmed that it is the function I was looking for.
I was misled by scatter_nd in TensorFlow, which has more general scattering capabilities.
In contrast, torch.scatter is simple and handy, with gradients tracked.
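For completeness, here is a minimal sketch of the torch.scatter approach with the same hypothetical (1, 3, 2) toy tensor: each source vector is written to a new position along seq_len, and gradients flow back through the scatter to the source.

```python
import torch

# Toy source of shape (batch_size=1, seq_len=3, elem_dim=2); values are arbitrary.
src = torch.tensor([[[1., 1.], [2., 2.], [3., 3.]]], requires_grad=True)

# Destination position for each source position along seq_len (swap 0 and 2),
# broadcast over elem_dim to shape (batch_size, seq_len, elem_dim).
dest_idx = torch.tensor([[2, 1, 0]]).unsqueeze(-1).expand(-1, -1, src.size(-1))

# Out-of-place scatter into a zero tensor; differentiable w.r.t. src.
out = torch.zeros_like(src).scatter(1, dest_idx, src)

out.sum().backward()
print(out[0, 2])   # the vector that was originally at position 0
print(src.grad)    # all ones: gradients flow through scatter back to src
```

Note the direction of the index differs from gather: with scatter the index says where each source element is written, while with gather it says where each output element is read from.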