Remove tensor elements by mask

Let's say I have a tensor a of shape [batch_size, 80, 2048]
and a mask tensor b of shape [batch_size] containing 0s and 1s. I want to remove every sample a[i] for which b[i] == 1 and keep the remaining samples in their original order.
So far I have done this with for loops, but I'm sure there is a faster way. Thanks

You could index a with a boolean mask along the batch dimension, keeping only the samples that meet the condition b != 1:

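import torch

a = torch.randn([10, 80, 2])    # 10 samples, each of shape [80, 2]
b = torch.randint(0, 2, (10,))  # one 0/1 flag per sample

res = a[b != 1]  # keeps a[i] where b[i] != 1, in original order; shape [num_kept, 80, 2]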
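A slightly fuller sketch, assuming b holds integer 0/1 flags with one entry per sample along dim 0 (the sizes below are made up for illustration). It checks that the boolean mask gives the same result as an index_select-based variant and as a plain Python loop, so the surviving samples stay in their original order.

```python
import torch

# Illustrative sizes only: a batch of 10 samples, each 80 x 2048.
batch_size = 10
a = torch.randn(batch_size, 80, 2048)
b = torch.randint(0, 2, (batch_size,))  # 1 = drop this sample, 0 = keep it

# Vectorized: a boolean mask applied along the first (batch) dimension.
keep = b != 1
res_mask = a[keep]

# Equivalent alternative: collect the indices to keep, then select along dim 0.
keep_idx = torch.nonzero(keep, as_tuple=True)[0]
res_select = a.index_select(0, keep_idx)

# Reference loop version, for comparison only (slow for large batches).
kept = [a[i] for i in range(batch_size) if b[i] != 1]
res_loop = torch.stack(kept) if kept else a.new_empty((0, 80, 2048))

print(res_mask.shape)  # (number of zeros in b, 80, 2048)
```

Note that boolean indexing returns a new tensor (a copy), not a view, so modifying the result does not change a. If b is already a bool tensor, a[~b] does the same thing.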