Consider an n x m matrix A and an n-dimensional vector x whose entries are integers from 0 to m-1. I'd like to mask A such that A[i][j] is masked iff j > x[i]. I'm looking for a method that uses the GPU and PyTorch. I don't think torch.gather is applicable here.
mask = (A < torch.arange(0, A.size(1), device=A.device, dtype=A.dtype).unsqueeze(0))
masked_A = torch.where(mask, A, torch.zeros(1, 1, dtype=A.dtype, device=A.device))
- the arange gives you a vector of the column indices j, and unsqueezing makes it broadcastable to the shape of A,
- the comparison gives a matrix of the same shape as A with the desired mask,
- torch.where selects the masked inputs and zero otherwise.
The reason to use where instead of mask.float() * A is to also mask away NaNs: NaN * 0 is still NaN, so the multiplication would leave them in place.
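A tiny sketch of the NaN point. The matrix and mask here are made-up values, chosen only to put a NaN in a position the mask drops:

```python
import torch

# Made-up example: a NaN sits at [0, 1], where the mask is False
A = torch.tensor([[1.0, float("nan")],
                  [3.0, 4.0]])
mask = torch.tensor([[True, False],
                     [True, True]])

# Multiplying by a float mask leaves the NaN in place, because NaN * 0 is NaN
by_mul = mask.float() * A

# torch.where takes 0 wherever mask is False, so the NaN is replaced
by_where = torch.where(mask, A, torch.zeros((), dtype=A.dtype))

print(by_mul)    # the NaN at [0, 1] survives
print(by_where)  # [0, 1] is 0
```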
[Edit: Depending on whether “masked” means “stays if the condition is set” or “is set to zero”, you might use <= in place of <.]
Thank you for your answer. My question may not have been clear, so I'd like to clarify. When I said each entry of x is “an integer from 0 to m-1,” I didn't mean x[i]=i. Each entry of x can be any integer from 0 to m-1 (multiple entries may have the same value). Since your answer doesn't depend on x, I think my question was misleading. I'm sorry.
arange produces the j; it doesn't depend on i, but is broadcast along the first dimension.
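To make the broadcasting concrete, here is a small sketch with hypothetical values for x. It compares against x reshaped to a column, which is what the question asks for: the (1, m) row of column indices and the (n, 1) column of thresholds broadcast to an (n, m) mask.

```python
import torch

n, m = 3, 4
x = torch.tensor([2, 0, 3])           # hypothetical thresholds, one per row

cols = torch.arange(m).unsqueeze(0)   # shape (1, m): column index j, repeated for every row
rows = x.unsqueeze(1)                 # shape (n, 1): threshold x[i], repeated for every column

mask = rows < cols                    # broadcasts to (n, m); mask[i, j] is True iff j > x[i]
print(cols.shape, rows.shape, mask.shape)
print(mask)
```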
Oh, I think I understand what you meant. I guess you mistakenly put A instead of x in the first line, and I believe it should be like the following:
mask = (x.unsqueeze(1) < torch.arange(0, A.size(1), device=A.device, dtype=A.dtype).unsqueeze(0))
Anyway, that solved my problem. Thank you very much!
Ah right, sorry. Glad you figured it out!
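For reference, a self-contained sketch of the whole thing, covering both readings of “masked”. The helper name mask_after and its zero_masked flag are hypothetical, not part of PyTorch. It runs on the GPU when one is available and checks the vectorized result against a plain loop:

```python
import torch

def mask_after(A: torch.Tensor, x: torch.Tensor, zero_masked: bool = True) -> torch.Tensor:
    """Hypothetical helper for entries with j > x[i].

    zero_masked=True  -> set A[i, j] to 0 where j > x[i], keep the rest
    zero_masked=False -> keep A[i, j] where j > x[i], set the rest to 0
    """
    cols = torch.arange(A.size(1), device=A.device)          # (m,) column indices j
    mask = x.to(A.device).unsqueeze(1) < cols.unsqueeze(0)    # (n, m): True iff j > x[i]
    if zero_masked:
        mask = ~mask                                          # keep only j <= x[i]
    zero = torch.zeros((), dtype=A.dtype, device=A.device)
    return torch.where(mask, A, zero)

device = "cuda" if torch.cuda.is_available() else "cpu"
n, m = 5, 6
A = torch.randn(n, m, device=device)
x = torch.randint(0, m, (n,), device=device)   # entries may repeat

out = mask_after(A, x)

# Plain loop on the CPU as a reference for the vectorized version
ref = A.cpu().clone()
for i in range(n):
    for j in range(m):
        if j > x[i].item():
            ref[i, j] = 0.0
assert torch.equal(out.cpu(), ref)
```

Using torch.where rather than a multiplication keeps the NaN behaviour discussed earlier in the thread.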