Indexing with sampled categorical

I might be doing something wrong, but when I index a tensor of logits with a sample drawn from a Categorical over those same logits, I expect to get one logit per sampled index, i.e. logits[sample].shape == sample.shape == logits.shape[:-1]. This is not the case:

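import torch
from torch.distributions import Categorical  # assuming torch.distributions
logits = torch.randn(6, 4)  # example shape: batch of 6 distributions over 4 categories
dist = Categorical(logits=logits)
sample = dist.sample()  # shape (6,), i.e. logits.shape[:-1]
indexed_logits = logits[sample]  # shape (6, 4), not (6,)
assert indexed_logits.shape == logits.shape[:-1]  # fails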
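A minimal sketch of what logits[sample] actually computes, assuming torch.distributions.Categorical and an example logits shape of (6, 4). Indexing with an integer tensor selects along the first dimension only, so the result has shape sample.shape + logits.shape[1:] rather than picking one entry from the last dimension. If the number of categories exceeds logits.shape[0], the same indexing can also fail with an out-of-bounds error.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.randn(6, 4)             # example: batch of 6 distributions over 4 categories
sample = Categorical(logits=logits).sample()
# sample has the batch shape: (6,) == logits.shape[:-1]

# Integer-tensor indexing selects whole rows along dim 0,
# so each sampled value picks a row of 4 logits.
indexed = logits[sample]               # shape (6, 4) == sample.shape + logits.shape[1:]
```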

Basically, I want to treat the categorical sample as indices into the last (category) dimension of logits.
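One way to select along the last dimension is torch.gather with the sample as the index. This is a sketch assuming torch.distributions.Categorical and an example logits shape of (2, 3, 5); gather requires the index to have the same number of dimensions as logits, hence the unsqueeze/squeeze around it.

```python
import torch
from torch.distributions import Categorical

logits = torch.randn(2, 3, 5)                 # example: batch shape (2, 3), 5 categories
dist = Categorical(logits=logits)
sample = dist.sample()                        # shape (2, 3), values in [0, 5)

# Add a trailing singleton dim so the index matches logits' rank,
# pick one entry along the last dim, then drop the singleton dim.
picked = logits.gather(-1, sample.unsqueeze(-1)).squeeze(-1)
assert picked.shape == sample.shape == logits.shape[:-1]

# If the normalized log-probability is what is needed rather than the raw
# logit, log_prob gives picked - logits.logsumexp(-1).
log_p = dist.log_prob(sample)
```

For 2-D logits, logits[torch.arange(logits.shape[0]), sample] selects the same entries; gather avoids building one arange per batch dimension when the batch shape has more than one dimension.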