Sampling words from a probability distribution

I’m trying to build a probabilistic model using PyTorch. I want to sample a batch of words from a categorical distribution (like the corresponding step in LDA). But in my model, the multinomial distribution used to generate the topic assignment for each word has two batch dimensions: its shape is (I, D, K), where I and D are batch dimensions and K is the topic dimension. Each element along the topic dimension is the count for the corresponding topic. Currently I sample the document-word topic assignments with a loop that iterates over all three dimensions, but it’s very slow. How can I speed this step up, by vectorizing it or some other approach?
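For concreteness, here is a minimal sketch of the kind of loop I mean. The sizes and tensor names (`topic_counts`, `probs`, `assignments`, `N`) are just placeholders for illustration, not my actual variables; my full runnable code is below.

```python
import torch

# Placeholder sizes (hypothetical): I, D are batch dims, K topics, N words per document
I, D, K, N = 4, 8, 10, 20

# Topic counts with two batch dimensions, shape (I, D, K)
topic_counts = torch.randint(1, 5, (I, D, K)).float()
# Normalize counts into categorical probabilities along the topic dimension
probs = topic_counts / topic_counts.sum(dim=-1, keepdim=True)

# The slow part: a Python loop over both batch dimensions and the words,
# drawing one topic assignment per word from the (I, D)-indexed categorical
assignments = torch.empty(I, D, N, dtype=torch.long)
for i in range(I):
    for d in range(D):
        for n in range(N):
            assignments[i, d, n] = torch.multinomial(probs[i, d], 1).item()
```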

Here is my full code; you can run it directly.