I am looking for an efficient, PyTorch-idiomatic way to zero out the values beyond a per-row index.
I have found one solution, but it seems more complicated than it needs to be.
Here is an example:
Given
x: tensor([[[0.7418, 0.3182, 0.4222, 0.0584, 0.0477]],
           [[0.4293, 0.3079, 0.2928, 0.8873, 0.9470]],
           [[0.6137, 0.3592, 0.2576, 0.8944, 0.4743]],
           [[0.6279, 0.2723, 0.2599, 0.6904, 0.9212]],
           [[0.2126, 0.2689, 0.2441, 0.8755, 0.4720]]])
x.shape: torch.Size([5, 1, 5])
and
z: tensor([3, 3, 1, 4, 1])
z.shape: torch.Size([5])
Modify x so that, in each row i, every value at or beyond index z[i] is set to zero (only the first z[i] values are kept):
x^: tensor([[[0.7418, 0.3182, 0.4222, 0.0000, 0.0000]],
            [[0.4293, 0.3079, 0.2928, 0.0000, 0.0000]],
            [[0.6137, 0.0000, 0.0000, 0.0000, 0.0000]],
            [[0.6279, 0.2723, 0.2599, 0.6904, 0.0000]],
            [[0.2126, 0.0000, 0.0000, 0.0000, 0.0000]]])
x^.shape: torch.Size([5, 1, 5])
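One common vectorized approach (a sketch, not necessarily the solution the question refers to) is to build a boolean mask by broadcasting a column-index range against z, then zero out the masked positions with masked_fill. It assumes x has shape (batch, 1, length) and that z holds one integer per batch row, as in the example; the names idx, mask and x_hat are only illustrative.

```python
import torch

# Input from the example above: shape (5, 1, 5)
x = torch.tensor([[[0.7418, 0.3182, 0.4222, 0.0584, 0.0477]],
                  [[0.4293, 0.3079, 0.2928, 0.8873, 0.9470]],
                  [[0.6137, 0.3592, 0.2576, 0.8944, 0.4743]],
                  [[0.6279, 0.2723, 0.2599, 0.6904, 0.9212]],
                  [[0.2126, 0.2689, 0.2441, 0.8755, 0.4720]]])
# Number of leading values to keep in each batch row, shape (5,)
z = torch.tensor([3, 3, 1, 4, 1])

# Column positions 0..length-1, created on the same device as x
idx = torch.arange(x.size(-1), device=x.device)

# Broadcast (1, 1, length) against (batch, 1, 1) -> mask of shape (batch, 1, length);
# mask[i, 0, j] is True exactly when j < z[i], i.e. for the positions to keep
mask = idx.view(1, 1, -1) < z.view(-1, 1, 1)

# Out-of-place: copy x with every position where mask is False set to 0
x_hat = x.masked_fill(~mask, 0.0)
```

If modifying x directly is acceptable, x.masked_fill_(~mask, 0.0) performs the same operation in place. Because the mask comes from a single broadcast comparison, there is no Python loop over the batch, and the same pattern works on GPU tensors.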