How can I quickly create a mask tensor from an index tensor?

Hi, I'm using PyTorch 1.1 and I have a question about creating a mask tensor from an index tensor.

Here is a simple example to illustrate.
Say I have an index tensor of size [2, 4] that looks like this:

tensor([[0, 0, 2, 4], [0, 0, 0, 3]])

From it, I want to get a mask tensor of size [2, 4, 6] like this:

tensor([[[1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 1., 0.]],
[[1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.]]])

I implemented it with two nested for loops, shown below, but when the tensor is large (for example, a mask of size [32, 8732, 21]) it is too slow to use during training.

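import torch
b = torch.tensor([[0, 0, 2, 4], [0, 0, 0, 3]])  # index tensor, shape [2, 4]
mask = torch.zeros(2, 4, 6)
for i in range(mask.size(0)):
    for j in range(mask.size(1)):
        mask[i, j, b[i, j]] = 1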

Is there a way to do this quickly without using a for loop?

Thank you in advance.

You could use scatter_ as shown here:

import torch

# setup
index = torch.tensor([[0, 0, 2, 4], [0, 0, 0, 3]])
mask = torch.zeros(2, 4, 6)

# loop
for i in range(mask.size(0)):
    for j in range(mask.size(1)):
        mask[i,j,index[i,j]] = 1

# scatter
mask2 = torch.zeros(2, 4, 6)
mask2.scatter_(2, index.unsqueeze(2), 1)

print((mask == mask2).all())
# > tensor(True)
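The same `scatter_` call works unchanged at the size mentioned in the question. Below is a minimal sketch that wraps it in a small function; `make_mask` is a hypothetical helper name, not a PyTorch API, and the random indices only stand in for real data. It assumes `index` is an int64 (`torch.long`) tensor, which `scatter_` requires, and it cross-checks the result against `torch.nn.functional.one_hot`, which builds the same one-hot layout as an integer tensor in recent PyTorch versions.

```python
import torch
import torch.nn.functional as F

def make_mask(index, num_classes):
    # Hypothetical helper (not part of PyTorch): returns a float mask of shape
    # index.shape + (num_classes,) with a 1 at position index[...] in the last dim.
    mask = torch.zeros(*index.shape, num_classes, device=index.device)
    # scatter_ writes the value 1 along the last dim at the positions given by index;
    # index is unsqueezed so it has the same number of dims as mask.
    mask.scatter_(-1, index.unsqueeze(-1), 1)
    return mask

# Random indices at the size from the question, giving a mask of [32, 8732, 21].
index = torch.randint(0, 21, (32, 8732))
mask = make_mask(index, 21)

print(mask.shape)  # torch.Size([32, 8732, 21])

# Cross-check against one_hot (which returns int64, hence .float()).
print(torch.equal(mask, F.one_hot(index, num_classes=21).float()))  # True
```

Creating the mask on `index.device` keeps the whole operation on the GPU when the indices already live there, so no host-device copy is needed during training.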

It works!
Thank you for your answer!
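For reference, here is a sketch of two other loop-free ways to build the same mask. These techniques were not suggested in the thread: option 1 uses a broadcasting comparison (`mask[i, j, k] = 1` exactly when `k == index[i, j]`), and option 2 uses advanced indexing as a direct translation of the double loop. Both assume the index tensor is on the CPU; on a GPU, `torch.arange` would also need `device=index.device`. In PyTorch 1.1 the comparison returns a `uint8` tensor rather than `bool`, but `.float()` converts either one.

```python
import torch

index = torch.tensor([[0, 0, 2, 4], [0, 0, 0, 3]])  # shape [2, 4]
num_classes = 6

# Option 1: broadcasting comparison.
# index.unsqueeze(-1) has shape [2, 4, 1] and torch.arange(6) has shape [6];
# the comparison broadcasts to [2, 4, 6] and is true exactly where k == index[i, j].
mask_cmp = (index.unsqueeze(-1) == torch.arange(num_classes)).float()

# Option 2: advanced indexing, mirroring mask[i, j, index[i, j]] = 1.
# rows has shape [2, 1] and cols has shape [1, 4]; together with index (shape [2, 4])
# they broadcast to [2, 4], so every (i, j) pair is assigned in a single call.
mask_idx = torch.zeros(2, 4, num_classes)
rows = torch.arange(index.size(0)).unsqueeze(1)
cols = torch.arange(index.size(1)).unsqueeze(0)
mask_idx[rows, cols, index] = 1

print(torch.equal(mask_cmp, mask_idx))  # True
```

Like `scatter_`, both options allocate the full mask up front; option 1 additionally creates a temporary comparison tensor before the conversion to float.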