How to quickly set selected rows of a batched tensor to zero

I have a tensor A with shape (b, 1) and a tensor B with shape (b, h, w). How can I do the following faster, with vectorized code instead of a Python loop?

for i in range(b):
    B[i, int(A[i]):] = 0

That is, B is a batch of b matrices, and for each matrix i I want to set every row from index A[i] onward to zero.
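For concreteness, here is the loop wrapped in a small self-contained function and run on a toy batch of all-ones matrices, so the effect of A is easy to see. This assumes A holds non-negative integer-valued row indices; the name `zero_rows_loop` is hypothetical and only used for illustration.

```python
import torch

def zero_rows_loop(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Reference (slow) version: for each matrix i, zero rows A[i], A[i]+1, ..., h-1.
    # Works on a copy so the input B is left unchanged.
    out = B.clone()
    for i in range(B.shape[0]):
        out[i, int(A[i]):] = 0
    return out

A = torch.tensor([[1], [3], [0]])  # start row for each of the 3 matrices
B = torch.ones(3, 4, 2)            # b=3 matrices, each of shape (h=4, w=2)
out = zero_rows_loop(A, B)

# Row sums per matrix: [[2, 0, 0, 0], [2, 2, 2, 0], [0, 0, 0, 0]]
print(out.sum(dim=2))
```

Matrix 0 keeps only row 0, matrix 1 keeps rows 0 to 2, and matrix 2 (A = 0) is zeroed entirely.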

This vectorized version builds a (b, h) boolean mask of the rows to zero and applies it in one step:

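import torch

b, h, w = 3, 4, 4

A = torch.randint(0, h, (b, 1))   # start row for each matrix, in [0, h)
B = torch.randn(b, h, w)
C = B.clone()                     # copy used to check against the loop

# Put a 1 at (i, A[i]) in a (b, h) matrix, then take a running sum along
# each row: every position at or after A[i] becomes nonzero, i.e. True.
idx = torch.zeros(b, h)
idx[torch.arange(b), A.squeeze(1)] = 1.
idx = idx.cumsum(1).bool()
B[idx] = 0.                       # the (b, h) mask zeros whole rows of B in place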

for i in range(b):
    C[i, int(A[i]):] = 0

print((B == C).all())
> tensor(True)
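A different technique that avoids the scatter and the cumulative sum is to compare a row-index range against A with broadcasting. This is a sketch, assuming A holds non-negative integer-valued indices (`.long()` truncates float values the same way `int()` does for non-negative numbers); `zero_rows_from` is a hypothetical helper name. Because `rows >= A[i]` is False for every row when A[i] == h, this also covers that edge case, which would be out of bounds for the `idx[torch.arange(b), ...] = 1.` indexing above.

```python
import torch

def zero_rows_from(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # A: (b, 1) start-row indices, B: (b, h, w).
    h = B.shape[1]
    rows = torch.arange(h, device=B.device)  # (h,)
    # Broadcast (h,) against (b, 1) -> (b, h); True where row index >= A[i].
    mask = rows >= A.long()
    # Add a trailing dim so the mask broadcasts over w; returns a new tensor.
    return B.masked_fill(mask.unsqueeze(-1), 0)

b, h, w = 3, 4, 4
A = torch.randint(0, h + 1, (b, 1))  # h itself is allowed: nothing is zeroed
B = torch.randn(b, h, w)

C = B.clone()
for i in range(b):
    C[i, int(A[i]):] = 0

print(torch.equal(zero_rows_from(A, B), C))  # True
```

If modifying B in place is acceptable, as in the answer above, `B[mask] = 0` with the same (b, h) mask has the same effect without allocating a new tensor.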