Consider an n×m matrix A and an n-dimensional vector x whose entries are integers from 0 to m-1. I'd like to mask A such that A[i][j] is masked iff j > x[i]. I'm looking for a method that uses the GPU with PyTorch. I don't think torch.gather is applicable here.
How about
mask = (A < torch.arange(0, A.size(1), device=A.device, dtype=A.dtype).unsqueeze(0))
masked_A = torch.where(mask, A, torch.zeros(1,1, dtype=A.dtype, device=A.device))
where
- the arange gives you a vector of the column indices j, and unsqueezing makes it broadcastable to A's shape,
- the comparison gives a matrix of the same shape as A with the desired mask,
- torch.where takes the entries of A where the mask is True and zero otherwise.
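A minimal end-to-end sketch of the approach, assuming the comparison is made against x (not A) and that "masked" means "set to zero", so it keeps the entries with j <= x[i]. The function name zero_beyond and the sample tensors are hypothetical; it runs on the GPU when one is available and falls back to the CPU otherwise.

```python
import torch

def zero_beyond(A: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: copy of A with A[i][j] set to 0 wherever j > x[i]."""
    # Column indices 0..m-1 as a (1, m) row, on the same device as A.
    j = torch.arange(A.size(1), device=A.device).unsqueeze(0)
    # x as an (n, 1) column so the comparison broadcasts to (n, m).
    keep = j <= x.to(A.device).unsqueeze(1)
    # Keep A where j <= x[i]; use a 0-dim zero tensor (broadcast) elsewhere.
    return torch.where(keep, A, torch.zeros((), dtype=A.dtype, device=A.device))

device = "cuda" if torch.cuda.is_available() else "cpu"
A = torch.arange(1.0, 13.0, device=device).reshape(3, 4)
x = torch.tensor([0, 2, 3], device=device)  # entries may be any values in 0..m-1
print(zero_beyond(A, x))
```

Everything stays on A's device, so the whole mask is built and applied on the GPU without a Python loop over rows.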
The reason to use where instead of mask.float() * A is that it also masks away NaNs: NaN * 0 is still NaN, so multiplying by a float mask would leave them in place.
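A small check of the NaN point, with made-up values; the mask is built by hand here and True marks the entries to keep.

```python
import torch

A = torch.tensor([[1.0, float("nan")],
                  [float("nan"), 4.0]])
# True marks the entries we want to keep; (0, 1) is the only masked position.
keep = torch.tensor([[True, False],
                     [True, True]])

by_multiply = keep.float() * A   # 0.0 * NaN is NaN, so (0, 1) stays NaN
by_where = torch.where(keep, A, torch.zeros((), dtype=A.dtype))  # (0, 1) becomes 0

print(by_multiply)
print(by_where)
```

The NaN at (1, 0) survives in both results because that position is kept; only the masked NaN at (0, 1) differs.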
Best regards
Thomas
[Edit: Depending on whether "masked" means "stays if the condition is met" or "is set to zero", you might use <= in place of >.]
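Both readings of "masked" can also be written with Tensor.masked_fill, an alternative to torch.where that isn't used in this thread; like where, it overwrites the masked positions, so NaNs there are cleared too. A sketch on a hypothetical 3×3 example, with x chosen so that x[i] != i:

```python
import torch

A = torch.arange(1.0, 10.0).reshape(3, 3)
x = torch.tensor([2, 0, 1])                 # hypothetical column bounds, one per row

j = torch.arange(A.size(1)).unsqueeze(0)    # shape (1, m)
beyond = j > x.unsqueeze(1)                 # True where j > x[i], shape (n, m)

# Reading 1: "masked" = set to zero, so zero out the entries with j > x[i].
zeroed_beyond = A.masked_fill(beyond, 0.0)
# Reading 2: keep only the entries with j > x[i] and zero the rest.
kept_beyond = A.masked_fill(~beyond, 0.0)

print(zeroed_beyond)
print(kept_beyond)
```

The two results are complementary: added together they give back A.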
Thank you for your answer. My question may not have been clear, so I'd like to clarify it. When I said each entry of x "is an integer from 0 to m-1," that doesn't necessarily mean x[i]=i. I meant that each entry of x can be any integer from 0 to m-1 (multiple entries may have the same value). Since your answer doesn't depend on x, I believe my question was misleading. I'm sorry.
The arange produces the j; it doesn't depend on i, but is broadcast along the first dimension.
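To see the broadcasting concretely, here is a shape check with hypothetical sizes n = 3, m = 5 and a repeated value in x. The column indices form a (1, m) row that doesn't depend on i, while x has to be turned into an (n, 1) column so that each row i is compared with its own x[i]:

```python
import torch

n, m = 3, 5
x = torch.tensor([4, 1, 1])           # one bound per row; values may repeat
j = torch.arange(m).unsqueeze(0)      # shape (1, m): column index j, independent of i
bound = x.unsqueeze(1)                # shape (n, 1): x[i], independent of j

mask = j > bound                      # broadcast to (n, m): mask[i, j] == (j > x[i])
print(j.shape, bound.shape, mask.shape)

# Without unsqueeze(1), x has shape (n,) and is broadcast as a (1, n) row.
# That fails unless n == m, and even then it would compare x[j] with j
# in a single row rather than x[i] with every j.
try:
    x > j
except RuntimeError as err:
    print("shape mismatch:", err)
```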
Oh, I think I understand what you meant. I guess you mistakenly put A instead of x in the first line, and I believe it should be like the following:
mask = (x.unsqueeze(1) < torch.arange(0, A.size(1), device=A.device, dtype=x.dtype).unsqueeze(0))
Anyway, that solved my problem. Thank you very much!
Ah right, sorry. Glad you figured it out!