I might be doing something wrong, but when I index a tensor of logits with a sample from a `Categorical` built from those logits, I expect `logits[sample]` to have shape `logits.shape[:-1]` (the same shape as `sample`), but this is not the case.
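```python
import torch
from torch.distributions import Categorical

logits = torch.randn(8, 5)  # example shape: 8 distributions over 5 categories
dist = Categorical(logits=logits)
sample = dist.sample()
indexed_logits = logits[sample]
assert indexed_logits.shape == logits.shape[:-1]
```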
Basically, I want to use the categorical sample as indices into the last (category) dimension of `logits`.
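For reference, here is a minimal sketch of one way to get the expected shape. It assumes the goal is to pick out, for each distribution in the batch, the logit of its sampled category. Integer-tensor indexing such as `logits[sample]` indexes the *first* dimension, so its result has shape `sample.shape + logits.shape[1:]`. `torch.gather` along the last dimension selects one entry per row instead. The helper name `select_sampled_logits` and the example shapes are illustrative only.

```python
import torch
from torch.distributions import Categorical


def select_sampled_logits(logits: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return the logit of each sampled category.

    logits: shape (*batch, K); sample: integer tensor of shape (*batch,)
    with values in [0, K). Result has shape (*batch,) == logits.shape[:-1].
    """
    # gather needs an index with as many dims as logits, so add a trailing
    # singleton dim, gather along the last (category) dim, then drop it again.
    return logits.gather(-1, sample.unsqueeze(-1)).squeeze(-1)


torch.manual_seed(0)
logits = torch.randn(4, 3, 5)  # example: a 4x3 batch of distributions over 5 categories
sample = Categorical(logits=logits).sample()  # shape (4, 3)
indexed_logits = select_sampled_logits(logits, sample)
assert indexed_logits.shape == logits.shape[:-1]
```

`torch.take_along_dim(logits, sample.unsqueeze(-1), dim=-1).squeeze(-1)` should be equivalent. If the normalized log-probabilities are what you actually need rather than the raw logits, `dist.log_prob(sample)` already returns them with shape `logits.shape[:-1]`.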