How to fill a matrix with zeros, with an exception

For example, I have a 4x4 2D tensor A:

A = torch.tensor(
    [[1,2,5,8],
     [5,8,4,6],
     [9,1,6,7],
     [1,4,3,1]])

I also have a 4x4 tensor M whose elements are all 1 by default.

Given left_mask = torch.tensor([2,0,1,3]), I want to set M[i][j] = 0 if j <= left_mask[i] and A[i][j] != 1.

In this case, the resulting M should be:

M = torch.tensor(
    [[1,0,0,1],
     [0,1,1,1],
     [0,1,1,1],
     [1,0,0,1]])

In other words, in each row i I want to set to 0 every entry of M at or to the left of column left_mask[i], except where the corresponding entry of A is 1.
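The rule above can also be expressed without a Python loop by building the whole condition as one boolean tensor with broadcasting. This is a minimal sketch assuming the A and left_mask from the example; the intermediate names (cols, in_range, zero_here) are only illustrative.

```python
import torch

A = torch.tensor([[1, 2, 5, 8],
                  [5, 8, 4, 6],
                  [9, 1, 6, 7],
                  [1, 4, 3, 1]])
left_mask = torch.tensor([2, 0, 1, 3])

# Column index j for every position, shape (1, 4)
cols = torch.arange(A.size(1)).unsqueeze(0)
# True where j <= left_mask[i]; left_mask.unsqueeze(1) has shape (4, 1),
# so the comparison broadcasts to a (4, 4) boolean tensor
in_range = cols <= left_mask.unsqueeze(1)
# Zero out only where the column is in range AND A is not 1
zero_here = in_range & (A != 1)
# M is 1 everywhere except where zero_here is True
M = (~zero_here).long()
print(M)
```

Because the row-wise limits come from broadcasting rather than a loop, the same code works unchanged for any number of rows.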

I know PyTorch has some functions for masking tensors, but I don't know how to express this condition with them. Can anyone help me?

Thank you very much.

This code should work:

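import torch

A = torch.tensor(
    [[1,2,5,8],
     [5,8,4,6],
     [9,1,6,7],
     [1,4,3,1]])
M = torch.ones(4, 4)
left_mask = torch.tensor([2, 0, 1, 3])
for idx in range(M.size(0)):
    # Indices of columns 0..left_mask[idx] in this row where A != 1
    # (the slice starts at column 0, so these are also column indices of M)
    cols = (A[idx, :left_mask[idx] + 1] != 1).nonzero()
    M[idx, cols] = 0.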
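The loop above makes one Python iteration per row. A loop-free sketch of the same in-place update can use `Tensor.masked_fill_`, which, like the loop, only overwrites the selected entries and leaves the rest of M unchanged (useful if M is not all ones). This assumes A and left_mask as in the question; the names cols and cond are only illustrative.

```python
import torch

A = torch.tensor([[1, 2, 5, 8],
                  [5, 8, 4, 6],
                  [9, 1, 6, 7],
                  [1, 4, 3, 1]])
left_mask = torch.tensor([2, 0, 1, 3])
M = torch.ones(4, 4)  # float tensor, as in the loop version

cols = torch.arange(A.size(1))
# Boolean (4, 4) condition: column within left_mask[i] and A[i][j] != 1
cond = (cols[None, :] <= left_mask[:, None]) & (A != 1)
# Write 0 in place wherever cond is True; other entries of M are untouched
M.masked_fill_(cond, 0.)
print(M)
```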

Thank you very much!