How to get a 2-D sub-tensor with a specific value condition from an original tensor in PyTorch?

I want to copy a 2-D torch tensor into a destination tensor that keeps only the values up to the first occurrence of the value 202 in each row and has zeros for all remaining items.



How can I do this?


I am not sure if it is the most efficient, but the following works:

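import torch

# for each position, count how many 202s occur at or after it (reverse cumulative sum)
mask = source_t.eq(202).flip(1).cumsum(dim=1).flip(1)
# column 0 holds the row total; positions up to and including the first 202 still equal it
mask = mask.eq(mask[:, 0].unsqueeze(1))
print(mask) # mask of the source values to keep (rows without any 202 are kept whole)

res = mask * source_t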
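Below is a self-contained sketch of the same reverse-cumsum idea, wrapped in a function so it can be run and checked. The function names `keep_until_first` and `keep_until_first_forward` and the sample tensor are hypothetical, chosen for illustration. The second function is an alternative I am swapping in: a forward cumsum that keeps every position with no 202 strictly before it, which avoids the two `flip` calls. Both assume a 2-D integer tensor and keep the first 202 itself.

```python
import torch


def keep_until_first(source_t: torch.Tensor, value: int = 202) -> torch.Tensor:
    """Zero every element after the first `value` in each row (the `value` itself is kept).

    Rows that never contain `value` are returned unchanged.
    """
    # Reverse cumulative count: entry (i, j) = number of `value`s at column j or later in row i.
    counts = source_t.eq(value).long().flip(1).cumsum(dim=1).flip(1)
    # Column 0 holds the row total; positions up to the first `value` still see that total.
    mask = counts.eq(counts[:, 0].unsqueeze(1))
    return mask * source_t


def keep_until_first_forward(source_t: torch.Tensor, value: int = 202) -> torch.Tensor:
    """Alternative: keep positions that have no `value` strictly before them."""
    hits = source_t.eq(value).long()
    # Inclusive running count minus the current hit = number of hits strictly before each position.
    before = hits.cumsum(dim=1) - hits
    return before.eq(0) * source_t


if __name__ == "__main__":
    # Hypothetical example input, not taken from the original question.
    source_t = torch.tensor([
        [1, 2, 202, 3, 4],
        [202, 5, 6, 202, 7],
        [8, 9, 10, 11, 12],
    ])
    print(keep_until_first(source_t))
    print(keep_until_first_forward(source_t))
```

Multiplying the boolean mask by the source tensor relies on PyTorch's type promotion, so the result has the source tensor's dtype. Both variants are vectorized over rows; the forward version does slightly less work because it skips the flips.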