Boolean matrix indexing

I ran into behavior that differs from NumPy while trying to use a 2D boolean tensor as a mask. This code reproduces it:

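import torch

a = torch.zeros((3, 3, 2))
mask = torch.zeros((3, 3), dtype=torch.bool)
# Mark positions (1, 0), (2, 0) and (2, 1) in the mask.
mask[(1, 2, 2), (0, 0, 1)] = True
print(mask)
# Mask the first two dimensions and select index 0 of the third.
a[mask, 0] = 1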

What I get is:

IndexError: The shape of the mask [3, 3] at index 1does not match the shape of the indexed tensor [3, 2] at index 1

In other words, I cannot combine a 2D boolean mask over the first two dimensions with an integer index into the third dimension. The equivalent NumPy code works as expected.
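
For comparison, here is a minimal sketch of the NumPy equivalent, followed by one possible workaround on the PyTorch side. The workaround is my assumption, not a confirmed fix: it selects the last dimension first with basic indexing (which returns a view) and only then applies the mask, so the assignment writes through to the original tensor. The names a_np, mask_np, rows and cols are just for illustration.

```python
import numpy as np
import torch

# NumPy reference: a boolean mask over the first two dimensions
# combined with an integer index into the third.
a_np = np.zeros((3, 3, 2))
mask_np = np.zeros((3, 3), dtype=bool)
mask_np[(1, 2, 2), (0, 0, 1)] = True
a_np[mask_np, 0] = 1

# PyTorch workaround (assumption): take a view of the slice first,
# then apply the 2D mask to that view.
a = torch.zeros((3, 3, 2))
mask = torch.zeros((3, 3), dtype=torch.bool)
rows = torch.tensor([1, 2, 2])
cols = torch.tensor([0, 0, 1])
mask[rows, cols] = True

# a[..., 0] is a view, so this assignment modifies a in place.
a[..., 0][mask] = 1

print(a[..., 0])
```

Only the three masked positions in the slice a[..., 0] are set to 1, and a[..., 1] stays untouched, which matches what NumPy produces for a_np.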