I might be doing something wrong, but when I try to index a tensor of logits with a `sample` drawn from a `Categorical` built on those same logits, I expect `logits[sample]` to have the same shape as `sample` (i.e. `logits.shape[:-1]`), but this is not the case.
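```python
import torch
from torch.distributions import Categorical

logits = torch.randn(4, 5)  # illustrative shape: batch of 4, 5 categories
dist = Categorical(logits=logits)
sample = dist.sample()  # shape (4,): one category index per row
indexed_logits = logits[sample]  # shape comes out as (4, 5)
assert indexed_logits.shape == logits.shape[:-1]  # fails: (4, 5) != (4,)
```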
Basically, I want to treat the categorical sample as indices into the last (category) dimension of `logits`.