Mixing boolean-mask indexing with integer indexing: bug?

I want to do something like this:

  import torch
  
  x = torch.randn((5, 4, 3))
  mask = torch.randn((5, 4)) < 0
  
  # This works: shape is (number of True entries in mask, 3)
  print(x[mask, :].shape)
  
  # This raises an IndexError. Why?
  print(x[mask, 0].shape)
  
  # I would like to set values like this:
  x[mask, 0] = 2

But the line x[mask, 0] raises this error:

  IndexError: The shape of the mask [5, 4] at index 1 does not match the shape of the indexed tensor [5, 3] at index 1
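One way to read the shapes in that message is sketched below. It assumes (an inference from the error text, not something confirmed here) that PyTorch counts the 2-D mask as occupying only dim 0 when it walks the index tuple, so the integer 0 ends up selecting along dim 1 instead of dim 2. The snippet does not run x[mask, 0] itself; it only reproduces the [5, 3] shape with x.select and compares it with what NumPy-style semantics would give.

```python
import torch

x = torch.randn((5, 4, 3))
mask = torch.randn((5, 4)) < 0

# Intended (NumPy-style) meaning of x[mask, 0]: the 2-D mask covers dims 0 and 1,
# and the integer 0 then picks index 0 of dim 2. x[mask][:, 0] selects the same
# elements (as a copy), so its shape is what x[mask, 0] was expected to return.
reference = x[mask][:, 0]
print(reference.shape == (int(mask.sum()),))  # True

# Hypothesis for the error: the mask is counted as covering only dim 0,
# so the integer 0 is applied to dim 1 (size 4) instead of dim 2.
misapplied = x.select(1, 0)
print(misapplied.shape)  # torch.Size([5, 3]), the "indexed tensor [5, 3]" in the error

# The [5, 4] mask is then compared with [5, 3] and the sizes differ at index 1.
print(mask.shape[1], misapplied.shape[1])  # 4 3
```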

What is going on here? Is this a bug or intended behavior, and is there a way to accomplish x[mask, 0] = 2 without creating an intermediate tensor?
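A possible workaround, sketched under the assumption that the goal is simply to write 2 into channel 0 wherever mask is True: take the integer index first with basic indexing, which yields a (5, 4) view sharing x's storage, and then apply the mask to that view. The clone named before exists only so the result can be checked and is not part of the technique.

```python
import torch

torch.manual_seed(0)
x = torch.randn((5, 4, 3))
mask = torch.randn((5, 4)) < 0
before = x.clone()  # kept only so the result can be checked afterwards

# x[..., 0] is a basic-indexing view of shape (5, 4) that shares storage with x.
# Boolean-mask assignment on that view writes straight into x; x itself is not copied.
x[..., 0][mask] = 2

# Equivalent in-place alternative on the same view.
y = before.clone()
y[..., 0].masked_fill_(mask, 2)

print(torch.equal(x, y))  # True
```

Note that the reverse order, x[mask][:, 0] = 2, would not modify x: x[mask] is advanced indexing and returns a copy, so the assignment lands in that copy.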