I’m trying to use a torch.distributions.multinomial.Multinomial distribution for a reinforcement learning task. Essentially, my network outputs probabilities over actions, which I want to feed into a Multinomial and then sample from.
My network output has shape (locations x objects), where each of 5 different objects can be placed in 1 of 10 different locations. This is all fine and good.
The problem I have is that in different rollouts, I might have a variable number of each object. So in addition to my (10 x 5) probs matrix, I also have a (5,) objects vector that holds how many of each object I have in that scenario. It might look like [2, 3, 1, 0, 6], where I have 2 of the first object, 3 of the second, etc. In this particular rollout, I’d want to sample from my first distribution over locations 2 times, since I have 2 of that object.
In the batch_size = 1 scenario, this works great. I do torch.unbind(probs, dim=-1), and then make a list of Multinomials, each with its own total_count equal to the corresponding entry of my objects vector, like this:
cats = [torch.distributions.multinomial.Multinomial(total_count=int(obj), probs=prob) for obj, prob in zip(objects, torch.unbind(probs, dim=-1))]
Or, I can do something like [cat.sample((n,)) for n, cat in zip(objects, cats)], if I instead decide to set total_count equal to 1 for each object distribution.
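To make the single-rollout case concrete, here is a minimal runnable sketch of the first variant. The random probs tensor stands in for my network output, and sample_counts is a hypothetical helper name I made up for this post. It assumes total_count has to be a plain Python int, and it handles a count of 0 explicitly instead of relying on how the sampler treats zero draws.

```python
import torch
from torch.distributions.multinomial import Multinomial

torch.manual_seed(0)

n_locations, n_objects = 10, 5
# Stand-in for the network output: each column is a distribution over locations.
probs = torch.softmax(torch.randn(n_locations, n_objects), dim=0)
objects = torch.tensor([2, 3, 1, 0, 6])


def sample_counts(p, n):
    """Place n copies of one object; returns per-location counts of shape (10,)."""
    if n == 0:
        # Nothing to place, so skip the sampler entirely.
        return torch.zeros_like(p)
    return Multinomial(total_count=n, probs=p).sample()


# One draw per object column, stacked back into a (locations x objects) matrix.
counts = torch.stack(
    [sample_counts(p, int(n)) for n, p in zip(objects, torch.unbind(probs, dim=-1))],
    dim=-1,
)
print(counts.shape)       # (10, 5)
print(counts.sum(dim=0))  # per-object totals, matching objects
```

Each column of counts says how many copies of that object landed in each location, so the column sums reproduce the objects vector.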
However, I’m completely stumped on how to handle the situation where batch_size > 1. I get an output from my network that’s, say, (32 x 10 x 5), and I can again make 5 different distributions, one per object. But since objects is now (32 x 5), when I slice it with, say, objects[:, 0], I get a vector with a different count for each batch element, like [3, 5, 0, ..., 1]. And as far as I can tell, I can’t pass a list or vector as either total_count or the argument to .sample().
So, how can I sample from a Multinomial a variable number of times? Ideally, I’d like to put in a (32 x 10) matrix for my probs, and then sample a number of times defined by a vector of size (32,). Even better, I’d like to be able to pass a (32 x 10 x 5) probs matrix together with a count matrix of size (32 x 1 x 5) that defines how many samples to draw, but I’d take either. Any ideas?
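For reference, one possible workaround is sketched below. It is not a built-in Multinomial feature. It swaps in torch.multinomial directly: draw max(objects) location indices with replacement for every (batch, object) pair, then mask out the draws beyond each pair's own count before accumulating. Since the draws are independent, keeping the first n of them should be equivalent to a Multinomial with total_count = n. The function name sample_variable_counts is hypothetical, and objects is taken as (32 x 5) rather than (32 x 1 x 5).

```python
import torch


def sample_variable_counts(probs, objects):
    """probs: (B, L, K), where probs[b, :, k] is a distribution over L locations.
    objects: (B, K) integer counts.
    Returns a (B, L, K) tensor of per-location placement counts."""
    B, L, K = probs.shape
    # One row per (batch, object) pair, each row a distribution over locations.
    flat_probs = probs.permute(0, 2, 1).reshape(B * K, L)
    flat_n = objects.reshape(B * K)
    # Draw at least one sample so the all-zero case never asks for 0 draws.
    max_n = max(int(flat_n.max()), 1)
    idx = torch.multinomial(flat_probs, max_n, replacement=True)  # (B*K, max_n)
    # keep[i, j] is True only for the first flat_n[i] draws of row i.
    keep = torch.arange(max_n, device=probs.device).unsqueeze(0) < flat_n.unsqueeze(1)
    counts = torch.zeros(B * K, L, dtype=probs.dtype, device=probs.device)
    # Masked-out draws contribute 0, kept draws contribute 1.
    counts.scatter_add_(1, idx, keep.to(probs.dtype))
    return counts.reshape(B, K, L).permute(0, 2, 1)


torch.manual_seed(0)
probs = torch.softmax(torch.randn(32, 10, 5), dim=1)  # stand-in network output
objects = torch.randint(0, 7, (32, 5))                # stand-in object counts
counts = sample_variable_counts(probs, objects)
print(counts.shape)  # (32, 10, 5)
```

This only covers sampling. The cost scales with the largest count in the batch, so it may be wasteful if one rollout has far more objects than the rest.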