Remove tensor elements by mask

Let’s say I have a tensor a of shape [Batchsize, 80, 2048] and a mask tensor b of shape [Batchsize] containing 0s and 1s. I want to remove every element a[i] for which b[i] == 1 and keep the remaining elements in their original order.
So far I have done this with for loops, but I’m sure there is a faster way. Thanks!

You could index a with a boolean mask that keeps all samples meeting the requirement b != 1:

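import torch

a = torch.randn([10, 80, 2])
b = torch.randint(0, 2, (10,))  # random 0/1 mask, one entry per sample

# b != 1 is a bool tensor of shape [10]; indexing dim 0 with it keeps the matching samples in order
res = a[b != 1]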
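For comparison, here is a minimal sketch that puts a loop-based version (like the one described in the question) next to the mask-indexing version, plus an equivalent variant that first collects the indices to keep and then calls `torch.index_select`. The helper name `remove_by_loop` is hypothetical; the final check only confirms that all three produce the same tensor, with the kept samples in their original order.

```python
import torch

def remove_by_loop(a, b):
    # Hypothetical loop version: collect the samples whose mask entry is not 1
    kept = [a[i] for i in range(a.size(0)) if b[i] != 1]
    # Stack them back into a [num_kept, *a.shape[1:]] tensor (handles the empty case)
    if kept:
        return torch.stack(kept)
    return a.new_empty((0, *a.shape[1:]))

a = torch.randn([10, 80, 2])
b = torch.randint(0, 2, (10,))

# Boolean-mask indexing along dim 0 keeps the samples with b != 1, in order
res_mask = a[b != 1]

# Same result via the explicit indices of the samples to keep
keep_idx = (b != 1).nonzero(as_tuple=True)[0]
res_index = torch.index_select(a, 0, keep_idx)

res_loop = remove_by_loop(a, b)

# res_mask has shape [k, 80, 2], where k is the number of zeros in b
print(res_mask.shape)
print(torch.equal(res_mask, res_index), torch.equal(res_mask, res_loop))
```

The mask version avoids the Python-level loop entirely; the `index_select` variant can be handy if you want to reuse the same indices for several tensors.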

Sorry for bothering you, but I have a similar problem. For example:
a = [Batchsize, 80, 2048]
b = [Batchsize], containing labels 0, 1, 2, 3, or 4. For each sample i in the batch, I want to remove the element at index b[i] (the label of that sample) from a[i] and keep the remaining elements in the same order.
I would like to know an efficient way to solve this problem. Thank you for your support.
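One possible reading of this follow-up is that, for each sample i, the row at index b[i] along the second dimension (size 80) should be dropped, giving a result of shape [Batchsize, 79, 2048]. Under that assumption, a minimal sketch: build a boolean keep-mask of shape [Batchsize, 80] by comparing an `arange` against the labels, index `a` with it, and reshape, since every sample loses exactly one row. The function names and the choice of dimension 1 are assumptions, not something stated in the question. The second helper covers the more literal reading (drop whole samples whose label equals their batch index), which reduces to the same masking idea as the earlier answer.

```python
import torch

def remove_label_rows(a, b):
    # Assumption: drop a[i, b[i], :] for every sample i (b holds indices along dim 1)
    batch_size, seq_len = a.shape[0], a.shape[1]
    # keep[i, j] is False exactly where j == b[i] (broadcast [1, seq_len] vs [batch_size, 1])
    keep = torch.arange(seq_len, device=a.device).unsqueeze(0) != b.unsqueeze(1)
    # Boolean indexing returns the kept rows flattened to [batch_size * (seq_len - 1), ...]
    # in row-major order, so a reshape restores the per-sample layout
    return a[keep].reshape(batch_size, seq_len - 1, *a.shape[2:])

def remove_samples_matching_index(a, b):
    # Alternative reading (assumption): drop sample i whenever b[i] == i
    return a[b != torch.arange(b.size(0), device=b.device)]

a = torch.randn(4, 80, 2048)
b = torch.randint(0, 5, (4,))
out = remove_label_rows(a, b)
print(out.shape)  # torch.Size([4, 79, 2048])
```

This sketch assumes every label is a valid index into dimension 1; a label outside that range would remove nothing for that sample and the reshape would fail.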