Torch multinomial in generate function

Hi, I came across the following generate function in a language model:

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens
            # (block_size is assumed to be defined elsewhere).
            idx_cond = idx[:, -block_size:]
            # Forward pass; keep only the logits for the last position.
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            # Turn the logits into a probability distribution over the vocabulary.
            probs = F.softmax(logits, dim=-1)
            # Sample the next token index from that distribution.
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append the sampled token and continue generating.
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

I have two questions regarding it:

Q1) Why is torch.multinomial used instead of torch.argmax for selecting the index of the next token to generate?

Q2) Why torch.multinomial specifically? Why not some other probability distribution?


Hi Mohit!

In regard to your questions about multinomial() (without trying to comment
on the larger use case that generate() is presumably a part of):

Let’s say that your two largest probs are rather close together (for example,
0.25 and 0.26). Using argmax() would always give you the index of 0.26,
ignoring, in a sense, that 0.25 is almost the same. On the other hand, using
multinomial() will give you the index of 0.26 26% of the time and the index
of 0.25 25% of the time, respecting the fact that the two values are quite close
to one another. (This may or may not be the behavior you want, but it does
make sense.)
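
To make this concrete, here is a minimal standalone sketch (my own toy
example, using a made-up four-token distribution whose two largest
probabilities are 0.26 and 0.25) that contrasts the two selection rules:

    import torch

    torch.manual_seed(0)

    # Toy distribution; the two largest probabilities are close together.
    probs = torch.tensor([0.26, 0.25, 0.25, 0.24])

    # argmax() deterministically returns index 0 every single time.
    print(probs.argmax())  # tensor(0)

    # multinomial() samples indices in proportion to probs: over many draws,
    # index 0 comes up about 26% of the time, index 1 about 25%, and so on.
    draws = torch.multinomial(probs, num_samples=10_000, replacement=True)
    print(torch.bincount(draws, minlength=4) / 10_000)  # roughly [0.26, 0.25, 0.25, 0.24]

(In generate() above, num_samples=1 and the draw is repeated once per
generated token, which amounts to the same thing.)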

You could use other probability distributions, and something else (such as
argmax(), which is, in effect, the distribution that gives you one specific
index 100% of the time) might better fit your use case. However, as noted
above, multinomial() samples index values according to probs, which could
well be what you want.
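
One standard way to reshape the distribution, for example (temperature
scaling, a common trick that is not part of the code you posted), is to
divide the logits by a temperature before the softmax. Low temperatures
sharpen the distribution toward argmax()-like behavior; high temperatures
flatten it toward uniform:

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([1.0, 0.96, 0.5])  # made-up raw scores from a model

    for temperature in (0.1, 1.0, 10.0):
        # Lower temperature -> sharper (more argmax-like); higher -> flatter.
        probs = F.softmax(logits / temperature, dim=-1)
        print(temperature, probs)

You would then pass the resulting probs to multinomial() exactly as in
your generate() function.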

Best.

K. Frank
