PyTorch lets you call `.reinforce()` directly on the sampled output of the random variable. For that to work, you need to draw the sample with `torch.multinomial` rather than `np.random.multinomial`, since a NumPy sample is not tracked by PyTorch and has no `.reinforce()` method. You can see how this works in, for example, the thread "What is action.reinforce(r) doing actually?".
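As a rough illustration of what the sample-then-reinforce pattern computes, here is a minimal sketch written without `.reinforce()`, for PyTorch versions where that method is not available. It samples with `torch.multinomial` and builds the score-function (REINFORCE) gradient by hand through a surrogate loss. The 4-action policy, the zero logits, and the reward value are made up for the example and are not from the original answer.

```python
import torch

torch.manual_seed(0)

# Hypothetical policy output: unnormalised scores for 4 actions.
logits = torch.zeros(4, requires_grad=True)
probs = torch.softmax(logits, dim=-1)

# Sample inside PyTorch, so the action index is a torch tensor
# that can be used to index `probs` directly.
action = torch.multinomial(probs, num_samples=1)

reward = 2.0  # made-up return observed for the sampled action

# Surrogate loss: -reward * log pi(action). Its gradient with respect
# to the policy parameters is the REINFORCE estimate.
log_prob = torch.log(probs.gather(-1, action))
loss = -(reward * log_prob).sum()
loss.backward()

print("sampled action:", action.item())
print("gradient w.r.t. logits:", logits.grad)
```

The gradient reaches `logits` because the sampled index is only used to select an entry of `probs`; the sampling step itself is not differentiated. A sample drawn with `np.random.multinomial` could still be used as an index this way, but it would have to be converted back to a tensor first.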