For example, I have a 4x4 tensor A:

A = torch.tensor([[1, 2, 5, 8],
                  [5, 8, 4, 6],
                  [9, 1, 6, 7],
                  [1, 4, 3, 1]])

I also have a 4x4 tensor M whose elements are all initially 1.

Given a tensor left_mask = torch.tensor([2, 0, 1, 3]), I want to set M[i][j] = 0 whenever j <= left_mask[i] and A[i][j] != 1.
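A direct, loop-based transcription of this rule can serve as a slow reference for checking any faster version. This is only a sketch; the helper name build_mask_loop is hypothetical, and it assumes A and left_mask are integer tensors on the CPU.

```python
import torch

def build_mask_loop(A: torch.Tensor, left_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: start with M full of ones (same shape and dtype as A).
    M = torch.ones_like(A)
    rows, cols = A.shape
    for i in range(rows):
        for j in range(cols):
            # Zero positions at or left of left_mask[i], unless A[i][j] == 1.
            if j <= left_mask[i] and A[i][j] != 1:
                M[i][j] = 0
    return M

A = torch.tensor([[1, 2, 5, 8],
                  [5, 8, 4, 6],
                  [9, 1, 6, 7],
                  [1, 4, 3, 1]])
left_mask = torch.tensor([2, 0, 1, 3])
print(build_mask_loop(A, left_mask))
```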

In this case, the resulting M should be:

M = torch.tensor([[1, 0, 0, 1],
                  [0, 1, 1, 1],
                  [0, 1, 1, 1],
                  [1, 0, 0, 1]])

In other words, in each row i I want to set to 0 the entries of M that are at or to the left of index left_mask[i], except where the corresponding entry of A equals 1.
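One possible vectorized approach, sketched under the assumption that A is 2D and left_mask is a 1D tensor with one entry per row on the same device as A, is to build the "at or left of left_mask[i]" condition by broadcasting a row of column indices against left_mask as a column, combine it with A != 1, and write zeros with masked_fill. The name build_mask is hypothetical.

```python
import torch

def build_mask(A: torch.Tensor, left_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper; assumes A is (n_rows, n_cols) and left_mask is (n_rows,).
    n_cols = A.shape[1]
    # Column indices 0..n_cols-1 as a row vector of shape (1, n_cols).
    cols = torch.arange(n_cols, device=A.device).unsqueeze(0)
    # Broadcast against left_mask as (n_rows, 1): True where j <= left_mask[i].
    left_of_idx = cols <= left_mask.unsqueeze(1)
    # Only positions where A is not 1 should be zeroed.
    to_zero = left_of_idx & (A != 1)
    # Start from ones and write 0 wherever to_zero is True.
    return torch.ones_like(A).masked_fill(to_zero, 0)

A = torch.tensor([[1, 2, 5, 8],
                  [5, 8, 4, 6],
                  [9, 1, 6, 7],
                  [1, 4, 3, 1]])
left_mask = torch.tensor([2, 0, 1, 3])
print(build_mask(A, left_mask))
```

With the A and left_mask from the question this should reproduce the M shown above. Because M starts as all ones, the same result could also be written as (~to_zero).to(A.dtype); masked_fill is used here because it still works if M is an existing tensor that is not all ones.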

I know PyTorch has some functions for masking tensors, but I don't know how to write this one. Can anyone help me?

Thank you very much.