Using .index_put() with a 3D tensor

If I have a 2D tensor like this one:

>>> torch.ones((2, 4))
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]])

and want to fill two positions per row with 0, to get:

[[1, 0, 1, 0],
 [0, 1, 1, 0]]

I can do:

torch.ones((2, 4)).index_put(
    (torch.arange(2).unsqueeze(1), torch.tensor([[1, 3], [0, 3]])),
    torch.tensor(0.))
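A minimal sketch of why this 2D call works: the row indices have shape (2, 1) and the column indices have shape (2, 2), so index_put broadcasts them into a (2, 2) grid of (row, column) coordinates, and the single value 0 is broadcast to every selected position. The names rows, cols, r and c are only for illustration.

```python
import torch

rows = torch.arange(2).unsqueeze(1)      # shape (2, 1): [[0], [1]]
cols = torch.tensor([[1, 3], [0, 3]])    # shape (2, 2): zero positions per row

# Broadcasting pairs each row index with every column index in that row:
# (0, 1), (0, 3), (1, 0), (1, 3)
r, c = torch.broadcast_tensors(rows, cols)
print(list(zip(r.flatten().tolist(), c.flatten().tolist())))

# The 0-dim value tensor is broadcast to all selected positions.
out = torch.ones(2, 4).index_put((rows, cols), torch.tensor(0.))
print(out)
```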

What about a 3D tensor? Let’s say I want to fill in a torch.ones(2, 3, 4) tensor with some zeros, to get:

tensor([[[1., 0., 1., 0.],
         [0., 1., 1., 0.],
         [1., 0., 0., 1.]],

        [[0., 1., 0., 1.],
         [1., 0., 0., 1.],
         [1., 1., 0., 0.]]])

if I have the indices of the zeros (along the last dimension) stored as:

torch.LongTensor([[[1,3],
                   [0,3],
                   [1,2]],

                  [[0,2],
                   [1,2],
                   [2,3]]])

Is there a way to use these indices to tell .index_put() where to place the zeros?

Hi,

For any number of dimensions, you can use scatter_ along the last dimension (dim=2 here) to achieve this:

import torch

ind = torch.tensor([[[1, 3],
                     [0, 3],
                     [1, 2]],

                    [[0, 2],
                     [1, 2],
                     [2, 3]]])

base = torch.ones(2, 3, 4)

# In place, along dim 2: base[i][j][ind[i][j][k]] = 0 for every (i, j, k)
base.scatter_(2, ind, 0)

print(base)
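The same broadcasting trick used in the 2D index_put call also works in 3D: build one index tensor per leading dimension, shaped so that it broadcasts against ind's shape (2, 3, 2). This is a sketch; the names batch, row, via_index_put and via_scatter are illustrative, and the final check compares it with scatter.

```python
import torch

ind = torch.tensor([[[1, 3], [0, 3], [1, 2]],
                    [[0, 2], [1, 2], [2, 3]]])

B, R, _ = ind.shape
batch = torch.arange(B).view(B, 1, 1)   # shape (2, 1, 1)
row = torch.arange(R).view(1, R, 1)     # shape (1, 3, 1)

# batch, row and ind broadcast to (2, 3, 2): each entry ind[b, r, k]
# addresses position (b, r, ind[b, r, k]) in the (2, 3, 4) tensor.
via_index_put = torch.ones(2, 3, 4).index_put((batch, row, ind), torch.tensor(0.))

# Out-of-place scatter along the last dimension gives the same result.
via_scatter = torch.ones(2, 3, 4).scatter(2, ind, 0.)

print(torch.equal(via_index_put, via_scatter))  # True
```

With index_put, every extra leading dimension needs its own arange index tensor, while scatter along the last dimension (dim=-1) needs only ind itself, which is why scatter is the simpler choice for higher-dimensional tensors.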

Hi Alban!

That’s exactly what I needed, thanks a lot
