For a matrix A and a vector x, how to mask A such that A[i][j] is masked iff j > x[i]?

Consider an n×m matrix A and an n-dimensional vector x whose entries are integers from 0 to m-1. I’d like to mask A such that A[i][j] is masked iff j > x[i]. I’m looking for a method that runs on the GPU with PyTorch. I don’t think torch.gather is applicable here.

How about

mask = (A < torch.arange(0, A.size(1), device=A.device, dtype=A.dtype).unsqueeze(0))
masked_A = torch.where(mask, A, torch.zeros(1,1, dtype=A.dtype, device=A.device))

where

  • the arange gives you a vector of the column indices j, and unsqueezing makes it broadcastable to A’s shape,
  • the comparison gives a boolean matrix of the same shape as A holding the desired mask,
  • torch.where selects the input where the mask is True and zero otherwise.
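A self-contained sketch of the same broadcast-and-where idea, assuming “masked” means “set to zero” and comparing against x reshaped to an (n, 1) column so that each row i uses its own x[i]. The function name zero_beyond and the example values are hypothetical; torch.arange, unsqueeze and torch.where are standard PyTorch calls. It runs on the GPU when one is available and falls back to the CPU otherwise.

```python
import torch

def zero_beyond(A: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return a copy of A with A[i][j] = 0 for every j > x[i]."""
    m = A.size(1)
    # Column indices j = 0..m-1 as a (1, m) row on A's device.
    j = torch.arange(m, device=A.device).unsqueeze(0)
    # x[i] as an (n, 1) column, so the comparison broadcasts to (n, m).
    mask = j > x.to(A.device).unsqueeze(1)
    # Write 0 where the mask is True and keep A elsewhere
    # (NaNs in masked positions are replaced as well).
    return torch.where(mask, torch.zeros((), dtype=A.dtype, device=A.device), A)

device = "cuda" if torch.cuda.is_available() else "cpu"
A = torch.arange(12, dtype=torch.float32, device=device).reshape(3, 4)
x = torch.tensor([0, 2, 2], device=device)  # repeated values are fine
result = zero_beyond(A, x)
print(result.tolist())  # [[0.0, 0.0, 0.0, 0.0], [4.0, 5.0, 6.0, 0.0], [8.0, 9.0, 10.0, 0.0]]
```

Swapping the second and third arguments of torch.where gives the other reading of “masked”, where only the entries with j > x[i] survive.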

The reason to use where instead of mask.float() * A is to also mask away NaNs: NaN * 0 (and inf * 0) is still NaN, whereas torch.where picks the zero outright.
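A small check of the NaN point, using a hand-picked 2×2 tensor (illustrative values, not data from the thread). Here True in keep means the entry survives, matching the torch.where(mask, A, ...) call in the answer.

```python
import torch

nan, inf = float("nan"), float("inf")
A = torch.tensor([[1.0, nan],
                  [2.0, inf]])
# True = keep the entry; the second column is masked out.
keep = torch.tensor([[True, False],
                     [True, False]])

by_multiply = keep.float() * A  # 0 * nan and 0 * inf are both nan, so they leak through
by_where = torch.where(keep, A, torch.zeros((), dtype=A.dtype))  # selects 0 outright

print(by_multiply)  # nan in both masked positions
print(by_where)     # zeros in both masked positions
```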

Best regards

Thomas

[Edit: Depending on whether “masked” means “kept where the condition holds” or “set to zero”, you might want the condition j <= x[i] in place of j > x[i].]
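The two readings from the edit note can be put side by side. This sketch swaps in Tensor.masked_fill instead of torch.where (both are standard PyTorch); masked_fill overwrites the selected entries whatever their value, so it drops NaNs as well. The values of A and x are made up for illustration.

```python
import torch

A = torch.arange(12, dtype=torch.float32).reshape(3, 4)
x = torch.tensor([0, 2, 3])

j = torch.arange(A.size(1)).unsqueeze(0)  # (1, m) row of column indices
beyond = j > x.unsqueeze(1)               # (n, m), True where j > x[i]

# Reading 1: "masked" = set to zero -> zero the entries with j > x[i].
zeroed_beyond = A.masked_fill(beyond, 0.0)
# Reading 2: "masked" = kept -> keep only j > x[i], zero the entries with j <= x[i].
kept_beyond = A.masked_fill(~beyond, 0.0)

print(zeroed_beyond)
print(kept_beyond)
```

Since the two masks are complements, the two results add back up to A.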


Thank you for your answer. My question may not have been clear, so I’d like to clarify it. When I said each entry of the vector x is “an integer from 0 to m-1,” I didn’t mean that x[i] = i. I meant that each entry of x can be any integer from 0 to m-1 (several entries may have the same value). Since your answer does not depend on x at all, I think my question was misleading. I’m sorry.

The arange produces the j values; it does not depend on i but is broadcast along the first dimension.
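To see the broadcasting concretely, here is a sketch with made-up values for x (including a repeated entry). The arange row carries j and the reshaped x carries x[i] as a column; neither depends on the other index, and the comparison combines them into an (n, m) mask.

```python
import torch

m = 4
x = torch.tensor([1, 1, 3])       # n = 3 rows; repeated values are allowed
j = torch.arange(m).unsqueeze(0)  # shape (1, m): column index j, identical in every row
x_col = x.unsqueeze(1)            # shape (n, 1): x[i], identical in every column

mask = x_col < j                  # broadcasts to (n, m); mask[i][j] is True iff j > x[i]
print(mask.shape)                 # torch.Size([3, 4])
print(mask)

# Without unsqueeze(1), a 1-D x of shape (n,) lines up with the *last* dimension
# of j: that raises an error when n != m, and silently compares x[j] with j when n == m.
```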


Oh, I think I understand what you meant. I guess you mistakenly put A instead of x in the first line; I believe it should be something like the following (with x reshaped to a column so that row i is compared with x[i]):
mask = (x.unsqueeze(1) < torch.arange(0, A.size(1), device=A.device, dtype=x.dtype).unsqueeze(0))
Anyway, that solved my problem. Thank you very much!

Ah right, sorry. Glad you figured it out!