For example, I have a 4*4 2D tensor A:
A = torch.tensor(
I also have a 4*4 tensor M whose elements are all 1 by default.
Given a tensor called left_mask = torch.tensor([2,0,1,3]), I want to let M[i][j] = 0 if (j <= left_mask[i] and A[i][j] != 1).
In this case, the resulting M should be:
M = torch.tensor(
In other words, in each row i, I want to set to 0 the entries of M that are at or to the left of the index given by left_mask[i], except where the corresponding entry of A is 1 (those entries stay 1).
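To make the rule concrete, here is a minimal sketch of one way it might be written with broadcasting. The values of A below are hypothetical placeholders chosen for illustration (they are not the values from my actual tensor), and the names cols and in_left are made up for this example.

```python
import torch

# Hypothetical 4x4 values for A, used only to illustrate the rule.
A = torch.tensor([[1, 0, 1, 0],
                  [0, 1, 0, 1],
                  [1, 1, 0, 0],
                  [0, 0, 1, 1]])
left_mask = torch.tensor([2, 0, 1, 3])

# Column indices 0..3 with shape (1, 4); left_mask reshaped to (4, 1).
# Broadcasting the comparison gives a (4, 4) boolean tensor that is
# True where j <= left_mask[i].
cols = torch.arange(A.size(1)).unsqueeze(0)
in_left = cols <= left_mask.unsqueeze(1)

# Start from all ones and zero out positions that are in the left
# region and where A is not 1.
M = torch.ones_like(A)
M[in_left & (A != 1)] = 0

print(M)
```

The same result could presumably be obtained without in-place assignment, for example with torch.where(in_left & (A != 1), 0, 1), if a fresh tensor is preferred.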
I know PyTorch has functions for masking tensors, but I don’t know how to combine them to implement this rule. Can anyone help me?
Thank you very much.