How can I replace part of a matrix with values from another matrix

Consider the following matrices:

>>> a = torch.Tensor([[1,2,3],[4,5,6], [7,8,9]])
>>> a
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
>>> b = torch.tensor([[1,1],[1,1]])
>>> b
tensor([[1, 1],
        [1, 1]])

I want to replace 4 elements of a with the values of b, where the indices are specified by X = [0, 2] and Y = [0, 2].

To have:

>>> a
tensor([[1., 2., 1.],
        [4., 5., 6.],
        [1., 8., 1.]])

Direct indexing should work:

a = torch.tensor([[1,2,3],[4,5,6], [7,8,9]])
b = torch.tensor([[1,1],[1,1]])

a[[0, 2], [[0], [2]]] = b

print(a)
# tensor([[1, 2, 1],
#         [4, 5, 6],
#         [1, 8, 1]])

Thanks, I am a novice at this stuff. Could you generalize your solution? Suppose I have multiple indices in X and multiple indices in Y. Based on your answer, I guess I should do:

a[X, Y.T] = b

but it doesn't seem to work.

UPDATE: the following seems to work, if this is what you meant:

a[X.reshape(-1,1), Y] = b
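
For reference, a minimal sketch of this generalization on the original a and b, with X and Y taken as tensors (this only illustrates the broadcasting rule: row indices of shape (2, 1) broadcast against column indices of shape (2,) to a 2x2 grid that matches b):

import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[1, 1], [1, 1]])
X = torch.tensor([0, 2])
Y = torch.tensor([0, 2])

# X.reshape(-1, 1) has shape (2, 1) and Y has shape (2,); together they
# broadcast to a (2, 2) grid of (row, column) pairs matching b's shape
a[X.reshape(-1, 1), Y] = b

print(a)
# tensor([[1, 2, 1],
#         [4, 5, 6],
#         [1, 8, 1]])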

In fact, I want to do this in batches, where things get more complicated: a, b, X, and Y are all batched. Suppose the batch size is 2 and a is:

>>> a
tensor([[[1, 1, 3],
         [1, 1, 6],
         [1, 1, 9]],

        [[1, 1, 3],
         [1, 1, 6],
         [1, 1, 9]]])

and this is the batched b:

>>> b
tensor([[[0, 0],
         [0, 0],
         [0, 0]],

        [[1, 1],
         [1, 1],
         [1, 1]]])

and these are X and Y:

>>> X
tensor([[0, 1, 2],
        [1, 2, 0]])

>>> Y
tensor([[1, 2],
        [0, 1]])

This is one of my attempts:

a[:, X.reshape(2,-1,1), Y.reshape(2,1,-1)] = b

Or maybe I have to loop over the batch dimension like this, right?


>>> for i in range(2):
...    a[i, X[i].reshape(-1,1), Y[i]] = b[i]

The right approach depends on your desired output. a[:, X.reshape(2,-1,1), Y.reshape(2,1,-1)] = b and the loop return different results, so could you check which one is the desired one?

The loop returns the desired result, where each a (in its own batch) gets its own data. However, the first method seems to update all of the as with the same (or maybe the last) data. How could I make it return the result of the loop?

However, I don't strictly need a to be batched; I could update a single a and keep the final result (some elements get overwritten). Maybe in that case both methods return the same result.

If the loop creates the desired output, this approach should work:

a = torch.tensor([[[1, 1, 3],
                   [1, 1, 6],
                   [1, 1, 9]],

                  [[1, 1, 3],
                   [1, 1, 6],
                   [1, 1, 9]]])

b = torch.arange(2*3*2).view(2, 3, 2)

X = torch.tensor([[0, 1, 2],
                  [1, 2, 0]])

Y = torch.tensor([[1, 2],
                  [0, 1]])

reference = a.clone()
out = a.clone()

for i in range(2):
    reference[i, X[i].reshape(-1,1), Y[i]] = b[i]

out[torch.arange(out.size(0))[:, None, None], X.view(X.size(0), -1, 1), Y.view(Y.size(0), 1, -1)] = b

print((reference - out).abs().max())
# tensor(0)

Note that I changed b to avoid duplicated entries in order to have a better comparison.
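
To see why the version with a plain : for the batch dimension updates every batch with the same data (this is my reading of the advanced-indexing broadcasting rules, using the same X and Y as above): with a slice in front, the index grids of all batches are applied to every a[i], so each batch receives all writes and overlapping positions keep whichever write lands last, while the explicit torch.arange batch index broadcasts together with X and Y and pairs batch i only with X[i] and Y[i]. A small sketch:

import torch

b = torch.arange(2 * 3 * 2).view(2, 3, 2)
X = torch.tensor([[0, 1, 2],
                  [1, 2, 0]])
Y = torch.tensor([[1, 2],
                  [0, 1]])

# slice for the batch dim: every batch receives the writes of all batches
wrong = torch.zeros(2, 3, 3, dtype=torch.long)
wrong[:, X.view(2, -1, 1), Y.view(2, 1, -1)] = b

# explicit batch index: batch i is paired only with X[i] and Y[i]
right = torch.zeros(2, 3, 3, dtype=torch.long)
batch = torch.arange(2)[:, None, None]
right[batch, X.view(2, -1, 1), Y.view(2, 1, -1)] = b

# position (0, 0) is only indexed by X[1]/Y[1], yet the slice version
# writes it into batch 0 as well
print(wrong[0, 0, 0])  # tensor(10), i.e. b[1, 2, 0]
print(right[0, 0, 0])  # tensor(0), untouched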


Thank you very much! Just a final question, if you have time: do these approaches differ in performance?

Yes, the second approach should be faster than the for loop, but the actual performance gain depends on the batch size / number of iterations.
You could profile it via:

def fun1(X, Y, reference, b):
    # loop over the batch and assign each sample separately
    for i in range(X.size(0)):
        reference[i, X[i].reshape(-1,1), Y[i]] = b[i]
    return reference

def fun2(X, Y, out, b):
    # single vectorized assignment using an explicit batch index
    out[torch.arange(out.size(0))[:, None, None], X.view(X.size(0), -1, 1), Y.view(Y.size(0), 1, -1)] = b
    return out


for batch_size in torch.logspace(1, 4, 4):
    batch_size = batch_size.int().item()
    print("batch_size ", batch_size)
    a = torch.randn(batch_size, 3, 3)
    b = torch.randn(batch_size, 3, 2)
    X = torch.randint(0, 3, (batch_size, 3))
    Y = torch.randint(0, 3, (batch_size, 2))
    
    reference = a.clone()
    out = a.clone()
    
    reference = fun1(X, Y, reference, b)
    out = fun2(X, Y, out, b)
    print((reference - out).abs().max())
    
    %timeit fun1(X, Y, reference, b)
    %timeit fun2(X, Y, out, b)

which shows:

batch_size  10
tensor(0.)
68.8 µs ± 93.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
12.7 µs ± 386 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
batch_size  100
tensor(0.)
676 µs ± 1.58 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
20.1 µs ± 1.31 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
batch_size  1000
tensor(0.)
6.86 ms ± 75.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
69.7 µs ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
batch_size  10000
tensor(0.)
70 ms ± 94.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
113 µs ± 6.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

on the CPU (for GPU profiling: don’t forget to add ; torch.cuda.synchronize() into the %timeit command).
As you can see, the more iterations the for loop approach needs, the larger the benefit of the vectorized version.
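
If you are not inside IPython/Jupyter, a rough sketch of the same comparison with the standard-library timeit (CPU tensors; for CUDA tensors you would also need to call torch.cuda.synchronize() inside the timed callables):

import timeit

import torch

batch_size = 1000
a = torch.randn(batch_size, 3, 3)
b = torch.randn(batch_size, 3, 2)
X = torch.randint(0, 3, (batch_size, 3))
Y = torch.randint(0, 3, (batch_size, 2))

reference = a.clone()
out = a.clone()
batch = torch.arange(batch_size)[:, None, None]

def loop_assign():
    # one advanced-indexing assignment per batch element
    for i in range(batch_size):
        reference[i, X[i].reshape(-1, 1), Y[i]] = b[i]

def vectorized_assign():
    # a single advanced-indexing assignment over the whole batch
    out[batch, X.view(batch_size, -1, 1), Y.view(batch_size, 1, -1)] = b

t_loop = timeit.timeit(loop_assign, number=100)
t_vec = timeit.timeit(vectorized_assign, number=100)
print(f"loop: {t_loop / 100 * 1e6:.1f} us/call, vectorized: {t_vec / 100 * 1e6:.1f} us/call")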
