I’m trying to use a torch.distributions.multinomial.Multinomial distribution for a reinforcement learning task. Essentially, my network outputs probabilities over actions, which I want to feed into a Multinomial and then sample from.
My network output looks something like (locations x objects), i.e. (10 x 5), where each of 5 different objects can be placed in 1 of 10 different locations. This is all fine and good.
The problem I have is that in different rollouts, I might have a variable number of each object. So in addition to my (10 x 5) probs matrix, I also have a (5,) objects vector that holds how many of each object I have in that scenario. It might look like [2, 3, 1, 0, 6], where I have 2 of the first object, 3 of the second, and so on. In this particular rollout, I’d want to sample from my first distribution over locations 2 times, since I have 2 of that object.
In the batch_size = 1 scenario, this works great. I do torch.unbind(probs, dim=-1), and then make a list of Multinomials, with an individual total_count equal to the corresponding number of objects in my objects vector, like this:
cats = [torch.distributions.multinomial.Multinomial(total_count=int(obj), probs=prob) for obj, prob in zip(objects, torch.unbind(probs, dim=-1))]
Or, if I instead set total_count to 1 for each object distribution, I can do something like [cat.sample((int(n),)) for n, cat in zip(objects, cats)], since .sample() takes a shape tuple rather than a bare count.
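For reference, here is a self-contained sketch of the batch_size = 1 approach, assuming a randomly generated (10 x 5) probs matrix (softmax over the locations dimension) and the example objects vector above. Two details matter in current PyTorch: Multinomial only accepts a plain Python int for total_count, and .sample() expects a shape such as (n,). The zero-count guard is a precaution, since the underlying torch.multinomial call may reject a request for 0 draws. counts_for is a hypothetical helper name.

```python
import torch
from torch.distributions import Multinomial

torch.manual_seed(0)

# Assumed single-rollout setup: 10 locations x 5 object types; softmax over
# dim 0 makes each column a distribution over locations.
probs = torch.softmax(torch.randn(10, 5), dim=0)
objects = torch.tensor([2, 3, 1, 0, 6])

def counts_for(prob, n):
    """Hypothetical helper: location counts for n objects of one type, shape (10,)."""
    n = int(n)  # Multinomial requires a Python int total_count
    if n == 0:  # skip the draw entirely for object types we don't have
        return torch.zeros_like(prob)
    return Multinomial(total_count=n, probs=prob).sample()

# Variant 1: total_count = number of objects of that type.
counts = torch.stack(
    [counts_for(p, n) for n, p in zip(objects, torch.unbind(probs, dim=-1))], dim=-1
)  # (10, 5), column k sums to objects[k]

# Variant 2: total_count=1, then draw n one-hot rows per object type.
cats = [Multinomial(total_count=1, probs=p) for p in torch.unbind(probs, dim=-1)]
draws = [cat.sample((int(n),)) if n > 0 else torch.zeros(0, 10)
         for n, cat in zip(objects, cats)]  # each (n, 10)
counts1 = torch.stack([d.sum(0) for d in draws], dim=-1)  # (10, 5)
```

Both variants produce the same kind of result: a (locations x objects) count matrix whose column sums equal the objects vector.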
However, I’m completely stumped on how to handle the case where batch_size > 1. Say I get an output from my network that’s (32 x 10 x 5); I can again make 5 different distributions, one per object. But since objects is now (32 x 5), when I slice to, say, objects[:, 0] I get a vector with a different number of objects per rollout, like [3, 5, 0, ..., 1]. And as far as I can tell, I can’t pass a list or vector into either total_count or .sample().
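The limitation can be reproduced with a small sketch, assuming random (32 x 10 x 5) probs and a hypothetical (32 x 5) objects tensor of random counts: passing a tensor as total_count raises NotImplementedError, while looping over every (rollout, object) pair and building one Multinomial each does work, just slowly.

```python
import torch
from torch.distributions import Multinomial

torch.manual_seed(0)
B, L, K = 32, 10, 5  # batch, locations, object types (shapes from the question)
probs = torch.softmax(torch.randn(B, L, K), dim=1)  # each probs[b, :, k] sums to 1
objects = torch.randint(0, 7, (B, K))               # hypothetical per-rollout counts

# Multinomial rejects a tensor (per-row) total_count.
try:
    Multinomial(total_count=objects[:, 0], probs=probs[:, :, 0])
    tensor_count_ok = True
except NotImplementedError:
    tensor_count_ok = False

# Slow but simple fallback: one Multinomial per (rollout, object type).
counts = torch.zeros(B, L, K)
for b in range(B):
    for k in range(K):
        n = int(objects[b, k])
        if n > 0:  # leave the column at zero when there are no objects
            counts[b, :, k] = Multinomial(total_count=n, probs=probs[b, :, k]).sample()
```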
So, how can I sample from a Multinomial a variable number of times? Ideally, I’d like to pass in a (32 x 10) probs matrix and then sample each row the number of times given by the matching entry of a vector of size (32,). Even better, I’d like to pass a (32 x 10 x 5) probs matrix along with a (32 x 1 x 5) matrix of sample counts, but I’d take either. Any ideas?
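One possible workaround, not a built-in Multinomial feature: a Multinomial(n, p) sample is the sum of n one-hot Categorical(p) draws, so you can draw max(objects) Categorical samples for every (rollout, object) pair in one call, mask out the draws beyond each pair's own count, and sum. The function name batched_variable_counts and the (B, K) counts layout are assumptions; a (32 x 1 x 5) counts tensor would just need .squeeze(1) first.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

def batched_variable_counts(probs: torch.Tensor, objects: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper.
    probs:   (B, L, K), each probs[b, :, k] a distribution over L locations.
    objects: (B, K) non-negative integer counts.
    Returns (B, L, K) counts with counts[b, :, k].sum() == objects[b, k]."""
    B, L, K = probs.shape
    max_n = int(objects.max())
    if max_n == 0:  # nothing to place anywhere
        return torch.zeros_like(probs)
    # Categorical treats the last dim as events, so move locations last: (B, K, L).
    cat = Categorical(probs=probs.transpose(1, 2))
    draws = cat.sample((max_n,))                    # (max_n, B, K) location indices
    # Keep only the first objects[b, k] draws for each (b, k).
    steps = torch.arange(max_n, device=probs.device).view(max_n, 1, 1)
    keep = steps < objects.unsqueeze(0)             # (max_n, B, K) bool mask
    one_hot = F.one_hot(draws, num_classes=L)       # (max_n, B, K, L)
    counts = (one_hot * keep.unsqueeze(-1)).sum(0)  # (B, K, L)
    return counts.transpose(1, 2).to(probs.dtype)  # back to (B, L, K)

# Usage with the shapes from the question (random data for illustration).
torch.manual_seed(0)
probs = torch.softmax(torch.randn(32, 10, 5), dim=1)
objects = torch.randint(0, 7, (32, 5))
counts = batched_variable_counts(probs, objects)  # (32, 10, 5)
```

The trade-off is memory proportional to max(objects) x B x K x L for the one-hot tensor, which is fine for small counts like these but wasteful if one rollout has far more objects than the rest.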