Hi there,

I have a very specific masking problem on the outputs, and I couldn't solve it with `gather`.

I have two tensors:

```
M = [[1 2 3]
     [4 5 6]
     [7 8 9]]
I = [1, 2, 1]
```

I need the output to be:

```
out = [[1 2 3]
       [0 5 0]
       [0 0 0]]
```

Basically, this operation sets to zero every element of "M" whose row index is greater than the entry of "I" for its column. For example, since the first entry of "I" is 1, every element in the first column of "M" below row 1 (that is, 4 and 7) becomes zero.
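To pin the rule down, here is a minimal plain-Python reference version, assuming "M" is a list of rows and "I" holds 1-based row cutoffs, one per column, exactly as in the example. The function name `mask_below` is hypothetical and only for illustration.

```python
def mask_below(M, I):
    """Zero every entry of M whose 1-based row index exceeds I[col].

    M is a list of rows; I holds one 1-based row cutoff per column.
    """
    return [
        [value if row + 1 <= I[col] else 0 for col, value in enumerate(row_values)]
        for row, row_values in enumerate(M)
    ]


M = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
I = [1, 2, 1]
print(mask_below(M, I))  # [[1, 2, 3], [0, 5, 0], [0, 0, 0]]
```

Note that the test is on the row position only, never on the values stored in "M".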

Is there an easy way to do this?

Thanks!

I’m not sure that your description of what you want matches your example.

Could you explain how you get from the inputs:

```
M =
[[1 2 3]
 [4 5 6]
 [7 8 9]]
and
I = [1,2,1]
```

to the output:

```
[[1 2 3]
 [0 5 0]
 [0 0 0]]
```

Should the 5 be there?

Yes, the 5 should be in the output. Basically, each element of "I" acts as a row-index cutoff for the corresponding column of "M": in each column, every element whose row index is greater than that column's entry in "I" is set to zero. The second element of "I" is 2, so in the second column of "M" only the elements below row 2 become zero. That is why 8 becomes 0 while 2 and 5 are kept.

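I'm not sure if there's a fast way to do this, but you could iterate through the columns. Here's an example:

```
import torch
M = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
I = torch.LongTensor([0, 1, 0])  # 0-based version of I = [1, 2, 1]
rows = torch.arange(M.size(0))  # row indices 0 .. n-1
# for each column x, keep the entries whose row index is <= i and zero the rest
out = torch.cat([((rows <= i).type_as(x) * x).unsqueeze(0) for x, i in zip(M.t(), I)]).t()
```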

This gives `[[1, 2, 3], [0, 5, 0], [0, 0, 0]]` for the example above.

Hi Richard,

Thanks for the solution! I also thought about using `cat`, but I wonder whether there is a built-in function that can do this efficiently without a for loop, since masking should be a very common operation.

I also tried `pack_padded_sequence` and `pad_packed_sequence`, but they require the index vector "I" to be sorted, which is not flexible enough.
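One loop-free option is broadcasting: compare a column of row indices against "I" laid out as a row, which yields the whole mask in a single comparison. This is a minimal sketch, assuming "I" holds 0-based cutoffs (as in the `torch.LongTensor([0, 1, 0])` example above); the helper name `mask_rows_after` is hypothetical.

```python
import torch


def mask_rows_after(M: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    """Zero M[r, c] wherever the row index r is greater than I[c].

    I is assumed to hold 0-based row cutoffs, one per column of M.
    """
    rows = torch.arange(M.size(0), device=M.device).unsqueeze(1)  # shape (n_rows, 1)
    keep = rows <= I.unsqueeze(0)  # broadcasts to (n_rows, n_cols)
    return M.masked_fill(~keep, 0)  # zero every entry outside the kept region


M = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
I = torch.tensor([0, 1, 0])
print(mask_rows_after(M, I).tolist())  # [[1.0, 2.0, 3.0], [0.0, 5.0, 0.0], [0.0, 0.0, 0.0]]
```

Because the mask depends only on row positions, the values in "M" can be in any order. `masked_fill` is used instead of multiplying by the mask so that `inf` or `nan` entries in the zeroed region do not leak through as `nan`.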

You could also do:

```
mask = M.index_select(0, I).diag().expand(3, 3) >= M
```
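Unpacking that one-liner step by step may help show what it compares. The intermediate names `picked` and `thresholds` are only for illustration, and the last line (multiplying `M` by the mask) is an assumption about how the mask would then be applied.

```python
import torch

M = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
I = torch.LongTensor([0, 1, 0])

picked = M.index_select(0, I)        # rows I[0], I[1], I[2] of M, stacked: shape (3, 3)
thresholds = picked.diag()           # thresholds[j] == M[I[j], j], here [1., 5., 3.]
mask = thresholds.expand(3, 3) >= M  # compares *values* column by column
out = M * mask.type_as(M)            # zero the entries that fail the comparison
print(out.tolist())  # [[1.0, 2.0, 3.0], [0.0, 5.0, 0.0], [0.0, 0.0, 0.0]]
```

Since the test is `value <= M[I[j], j]` rather than `row <= I[j]`, it reproduces the row cutoff only when every column of "M" is in ascending order.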

This is a very smart solution!

However, it assumes that each column of "M" is sorted in ascending order. For example, with

```
M = torch.Tensor([[1, 7, 4], [4, 5, 6], [7, 2, 6]])
I = torch.LongTensor([0, 2, 1])
mask = M.index_select(0, I).diag().expand(3, 3) >= M
```
the results will be