Scattering of data

I need functionality that comes close to TensorFlow's scatter_nd function, I guess.

Let’s assume that I have an m x n matrix A and four vectors N, S (of size n) and E, W (of size m). I need to create a new matrix B in which N and S replace the first and last row of A, and W and E replace its first and last column:

    | N1    N2     ...  Nn-1     Nn   |
    | W2    A2,2   ...  A2,n-1   E2   |
B = | W3    A3,2   ...  A3,n-1   E3   |
    | ...   ...    ...  ...      ...  |
    | Wm-1  Am-1,2 ...  Am-1,n-1 Em-1 |
    | S1    S2     ...  Sn-1     Sn   |

The vectors agree at the corners (N1 = W1, Wm = S1, …), so it does not matter whether N1 or W1 is put at B1,1, etc.
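Since the question is phrased in terms of scatter_nd, here is a rough PyTorch analog for the 2-d case, offered only as a sketch: `Tensor.index_put_` with a tuple of index tensors writes values at explicit (row, column) coordinates, much like scatter_nd does with an index list. The helper name `scatter_boundary` is hypothetical. It relies on the corner condition above: each corner coordinate is written twice, and with `accumulate=False` PyTorch does not specify which of the duplicate writes wins.

```python
import torch

def scatter_boundary(A, N, S, E, W):
    """Hypothetical helper: return a copy of the m x n matrix A whose first/last
    row are replaced by N/S (length n) and first/last column by W/E (length m),
    written scatter-style via index_put_ with explicit (row, col) coordinates."""
    m, n = A.shape
    rows = torch.arange(m)
    cols = torch.arange(n)
    # Row and column coordinates of the four edges, in the order N, S, W, E.
    row_idx = torch.cat([torch.zeros(n, dtype=torch.long),
                         torch.full((n,), m - 1, dtype=torch.long),
                         rows,
                         rows])
    col_idx = torch.cat([cols,
                         cols,
                         torch.zeros(m, dtype=torch.long),
                         torch.full((m,), n - 1, dtype=torch.long)])
    values = torch.cat([N, S, W, E])
    B = A.clone()  # leave A itself untouched
    # Corners appear twice in (row_idx, col_idx); which write wins is unspecified,
    # which is harmless only because N1 = W1, Nn = E1, S1 = Wm and Sn = Em.
    B.index_put_((row_idx, col_idx), values)
    return B
```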

The above is just an example; I need the same for 3- and 4-dimensional tensors, where the “boundaries” of the tensors are replaced by matrices and 3-d tensors, respectively.
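A minimal sketch of the n-dimensional case, using basic indexing rather than a scatter: for every dimension d, the first and last slice along d are overwritten. The helper name `set_boundary` and its input layout (`faces[d]` is a `(low, high)` pair of tensors, each shaped like A with dimension d removed) are assumptions for illustration. As in the 2-d case, the faces must agree where they meet; where they overlap, faces of later dimensions are written last and win.

```python
import torch

def set_boundary(A, faces):
    """Hypothetical helper: return a copy of A (any number of dims) whose
    boundary slices are replaced. faces[d] = (low, high) holds the new first
    and last slice of A along dimension d."""
    B = A.clone()  # leave A itself untouched
    for d, (low, high) in enumerate(faces):
        index = [slice(None)] * B.dim()
        index[d] = 0           # first slice along dimension d
        B[tuple(index)] = low
        index[d] = -1          # last slice along dimension d
        B[tuple(index)] = high
    return B
```

For a matrix, `set_boundary(A, [(N, S), (W, E)])` writes N/S into the first/last row and W/E into the first/last column.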

Any help is appreciated.

Hi Matthias!

You may modify A (or a copy of it) by assigning into it using slices:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> m = 5
>>> n = 6
>>>
>>> A = 5 * torch.ones (m, n)
>>>
>>> N = 1 * torch.ones (n)
>>> S = 2 * torch.ones (n)
>>> E = 3 * torch.ones (m)
>>> W = 4 * torch.ones (m)
>>>
>>> A[ 0, :] = N
>>> A[-1, :] = S
>>> A[:, -1] = E
>>> A[:,  0] = W
>>>
>>> A
tensor([[4., 1., 1., 1., 1., 3.],
        [4., 5., 5., 5., 5., 3.],
        [4., 5., 5., 5., 5., 3.],
        [4., 5., 5., 5., 5., 3.],
        [4., 2., 2., 2., 2., 3.]])
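If you would rather not write into A at all, B can also be assembled out of place with `torch.cat`; this is only a sketch, and the name `build_B` is hypothetical. It takes the four corners from N and S, so it agrees with the slice-assignment result above only when the corner conditions from the question hold. With the constant example vectors used above, the corners would come from N and S (1 and 2) rather than from W and E (4 and 3).

```python
import torch

def build_B(A, N, S, E, W):
    """Hypothetical helper: assemble B without any in-place writes, framing the
    interior of the m x n matrix A with W/E on the sides and N/S on top and bottom."""
    middle = torch.cat([W[1:-1, None],     # first column, rows 2..m-1
                        A[1:-1, 1:-1],     # interior of A
                        E[1:-1, None]],    # last column, rows 2..m-1
                       dim=1)
    # Corners come from N (top row) and S (bottom row).
    return torch.cat([N[None, :], middle, S[None, :]], dim=0)
```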

Best.

K. Frank

Hi Frank,

thanks. Two questions on this: How do I realise this in C++ (Tensor Indexing API — PyTorch main documentation) and does the above work when I need to compute gradients of the resulting tensor?

Best,
Matthias

Hi Matthias!

Yes, this will work fine for computing gradients both with respect to A and the
N, S, E, W vectors.

Note that, as written, my suggestion modifies A in place. This can complicate,
but does not invalidate, the use of gradients.
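A minimal sketch of the gradient case, assuming A and all four vectors require grad: assigning directly into A would raise a RuntimeError, because A is a leaf tensor that requires grad, so the slices are assigned into a clone instead. Overwritten entries receive zero gradient: after `B.sum().backward()`, `A.grad` is 1 on the interior and 0 on the boundary, and the first and last entries of N and S get 0 because W and E are written over the corners afterwards.

```python
import torch

m, n = 5, 6
A = torch.randn(m, n, requires_grad=True)
N = torch.randn(n, requires_grad=True)
S = torch.randn(n, requires_grad=True)
E = torch.randn(m, requires_grad=True)
W = torch.randn(m, requires_grad=True)

# A[0, :] = N would fail here (in-place write into a leaf that requires grad),
# so assign into a clone; autograd tracks the slice assignments on the clone.
B = A.clone()
B[ 0, :] = N
B[-1, :] = S
B[:, -1] = E
B[:,  0] = W

B.sum().backward()
print(A.grad)  # 1 on the interior, 0 on the overwritten boundary
print(N.grad)  # 0 at both ends (overwritten by W and E), 1 elsewhere
```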

(I don’t know anything about the C++ api.)

Best.

K. Frank

Thanks, Frank.

I tried your suggestion and translated it to C++ according to Tensor Indexing API — PyTorch main documentation. The following does indeed work:

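A.index_put_({ 0, "..."}, N);  // first row,    Python: A[ 0, ...] = N
A.index_put_({-1, "..."}, S);  // last row,     Python: A[-1, ...] = S
A.index_put_({"...", -1}, E);  // last column,  Python: A[..., -1] = E
A.index_put_({"...",  0}, W);  // first column, Python: A[...,  0] = W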