Filling torch tensor with zeros after certain index

I am looking for an efficient, idiomatic PyTorch approach to zeroing out values beyond certain indices.

I have found one solution, but it seems more complicated than it needs to be.

Here is an example:

Given

x: tensor([[[0.7418, 0.3182, 0.4222, 0.0584, 0.0477]],

        [[0.4293, 0.3079, 0.2928, 0.8873, 0.9470]],

        [[0.6137, 0.3592, 0.2576, 0.8944, 0.4743]],

        [[0.6279, 0.2723, 0.2599, 0.6904, 0.9212]],

        [[0.2126, 0.2689, 0.2441, 0.8755, 0.4720]]])

x.shape: torch.Size([5, 1, 5])

and

z: tensor([3, 3, 1, 4, 1])

z.shape: torch.Size([5])

Modify x so that, in each row i, the values at positions z[i] and beyond along the last dimension are set to zero:

x^: tensor([[[0.7418, 0.3182, 0.4222, 0.0000, 0.0000]],

        [[0.4293, 0.3079, 0.2928, 0.0000, 0.0000]],

        [[0.6137, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.6279, 0.2723, 0.2599, 0.6904, 0.0000]],

        [[0.2126, 0.0000, 0.0000, 0.0000, 0.0000]]])

x^.shape: torch.Size([5, 1, 5])
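For reference, the intended semantics can be pinned down with a plain Python loop. This is a slow sketch for checking results, not an answer to the question, and `zero_beyond_loop` is a hypothetical helper name: for each row i, everything from position z[i] onward in the last dimension is set to zero.

```python
import torch

def zero_beyond_loop(x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Reference version: for each batch row i, keep x[i, :, :z[i]] and zero the rest.
    out = x.clone()
    for i in range(x.size(0)):
        out[i, :, int(z[i]):] = 0
    return out

x = torch.tensor([[[0.7418, 0.3182, 0.4222, 0.0584, 0.0477]],
                  [[0.4293, 0.3079, 0.2928, 0.8873, 0.9470]],
                  [[0.6137, 0.3592, 0.2576, 0.8944, 0.4743]],
                  [[0.6279, 0.2723, 0.2599, 0.6904, 0.9212]],
                  [[0.2126, 0.2689, 0.2441, 0.8755, 0.4720]]])
z = torch.tensor([3, 3, 1, 4, 1])

x_hat = zero_beyond_loop(x, z)
print(x_hat)
```

A vectorized answer should produce exactly the same tensor as this loop.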

The following is the same idea, but with less code:

import torch
x = torch.tensor([[[0.7418, 0.3182, 0.4222, 0.0584, 0.0477]],
                  [[0.4293, 0.3079, 0.2928, 0.8873, 0.9470]],
                  [[0.6137, 0.3592, 0.2576, 0.8944, 0.4743]],
                  [[0.6279, 0.2723, 0.2599, 0.6904, 0.9212]],
                  [[0.2126, 0.2689, 0.2441, 0.8755, 0.4720]]])

z = torch.tensor([3, 3, 1, 4, 1])

# one-line trick: mask is True where the column index is below z[i]; it is reshaped to x's shape and multiplied in
(torch.arange(x.size(2)) < z[..., None]).unsqueeze(1) * x

This gives the desired result.
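To see why the one-liner works, here is the same expression split into steps with the intermediate shapes noted. This sketch uses random values of the same shape as in the question; the intermediate variable names are just for illustration.

```python
import torch

x = torch.rand(5, 1, 5)               # same shape as in the question; random values
z = torch.tensor([3, 3, 1, 4, 1])

cols = torch.arange(x.size(2))        # shape [5]: tensor([0, 1, 2, 3, 4])
limits = z[..., None]                 # shape [5, 1]: one limit per row
mask = cols < limits                  # broadcasts to [5, 5], True where col < z[i]
mask = mask.unsqueeze(1)              # shape [5, 1, 5], lines up with x
x_masked = mask * x                   # bool * float promotes to float; False -> 0
print(mask.shape, x_masked.dtype)
```

The key step is the broadcast comparison: a [5] row of column indices against a [5, 1] column of limits yields one boolean row per batch element.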


Thank you, Levi. This checks out and works.

Here is a slight modification that creates the arange on the same device as z and uses unsqueeze consistently:

(torch.arange(x.size(2), device=z.device) < z.unsqueeze(1)).unsqueeze(1) * x
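One caveat with multiplying by the mask: any inf or NaN past the cutoff stays NaN rather than becoming zero, because inf * 0 is NaN in IEEE arithmetic. If that matters, `masked_fill` writes zeros directly instead. This is a sketch of that variant; `zero_beyond` is a hypothetical helper name, and it assumes z lives on the same device as x.

```python
import torch

def zero_beyond(x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Column indices created on the same device as x (z is assumed to live there too).
    cols = torch.arange(x.size(-1), device=x.device)
    # keep[i, 0, j] is True where j < z[i]; unsqueeze to [N, 1, L] to line up with x.
    keep = (cols < z.unsqueeze(1)).unsqueeze(1)
    # Write zeros where keep is False instead of multiplying, so inf/NaN past the cutoff become 0.
    return x.masked_fill(~keep, 0.0)

x = torch.rand(5, 1, 5)
x[0, 0, 4] = float("inf")   # beyond z[0] = 3, so it should be zeroed
z = torch.tensor([3, 3, 1, 4, 1])

y = zero_beyond(x, z)
y_mul = (torch.arange(x.size(2)) < z[..., None]).unsqueeze(1) * x
print(y[0, 0, 4], y_mul[0, 0, 4])   # masked_fill gives 0; the multiply gives nan
```

For finite inputs both versions return the same values.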

The solution works beautifully, but if you know a way to achieve this without creating a new tensor (the arange), that would be even better.

The reason I need this is that I am trying to use the pytorch/glow project, which does not support dynamically creating new tensors (see glow/issues/3932).
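One possible workaround, assuming the maximum length of the last dimension is known in advance, is to build the index range once at module construction and keep it as a buffer, so that forward only slices and compares it instead of calling torch.arange on every call. `ZeroBeyond` and `max_len` are hypothetical names, and I have not checked whether Glow accepts the resulting graph; also note that the comparison still produces an intermediate mask tensor.

```python
import torch
from torch import nn

class ZeroBeyond(nn.Module):
    """Hypothetical module: zeroes x[i, :, j] for j >= z[i].

    Assumes the last dimension of x never exceeds max_len, so the index range
    can be built once at construction time and stored as a buffer.
    """

    def __init__(self, max_len: int):
        super().__init__()
        # Created once; moves with the module under .to(device); not a trainable
        # parameter, and persistent=False keeps it out of the state_dict.
        self.register_buffer("cols", torch.arange(max_len), persistent=False)

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        cols = self.cols[: x.size(-1)]                  # view into the precomputed range
        keep = (cols < z.unsqueeze(1)).unsqueeze(1)     # [N, 1, L] boolean mask
        return keep * x                                 # same multiply as the one-liner

m = ZeroBeyond(max_len=5)
x = torch.rand(5, 1, 5)
z = torch.tensor([3, 3, 1, 4, 1])
out = m(x, z)
print(out)
```

Whether this helps depends on how the model is exported: the arange then happens at construction rather than inside forward.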