How to sample from a Multinomial a variable number of times?

I’m trying to use a torch.distributions.multinomial.Multinomial distribution for a reinforcement learning task. Essentially, my network outputs probabilities over actions, which I want to feed into a Multinomial and then sample from.

My network output looks something like (locations x objects), where each of 5 different objects can be placed in 1 of 10 different locations. This is all fine and good.

The problem I have is that in different rollouts, I might have a variable number of each object. So in addition to my (10 x 5) probs matrix, I also have a (5,) objects vector holding how many of each object I have in that scenario. It might look like [2, 3, 1, 0, 6], where I have 2 of the first object, 3 of the second, etc. In this particular rollout, I’d want to sample from my first distribution over locations 2 times, since I have 2 of that object.
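For concreteness, a single rollout’s tensors might look like this (the names and the random probs are just illustrative):

import torch

# (locations x objects): column j is a distribution over the 10 locations for object j
probs = torch.softmax(torch.randn(10, 5), dim=0)
# how many of each object exist in this rollout
objects = torch.tensor([2, 3, 1, 0, 6])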

In the batch_size = 1 scenario, this works great. I do torch.unbind(probs, dim=-1), and then make a list of Multinomials, with an individual total_count equal to the corresponding number of objects in my objects vector, like this:

cats = [torch.distributions.multinomial.Multinomial(total_count = obj, probs=prob) for obj, prob in zip(objects, torch.unbind(probs, dim=-1))]

Or, if I instead set total_count equal to 1 for each object distribution, I can do something like [cat.sample((n,)).sum(dim=0) for n, cat in zip(objects, cats)].

However, I’m completely stumped on how to handle the situation where my batch_size > 1. I get an output from my network that’s, say, (32 x 10 x 5), and I can again make 5 different distributions. But since my objects is now (32 x 5), when I slice it to, say, objects[:, 0], I get a slice with a variable number of objects in it, like [3, 5, 0, ..., 1]. And as far as I can tell, I can’t pass a list or vector to either total_count or .sample().

So, how can I sample from a Multinomial a variable number of times? Ideally, I’d like to put in a (32 x 10) matrix for my probs, and then sample a number of times defined by a vector of size (32,). Even better, I’d like to be able to pass a (32 x 10 x 5) probs matrix along with a count matrix of size (32 x 1 x 5) that defines how many times to sample from each distribution, but I’d take either. Any ideas?

Hi @drj,

Could you share a minimal reproducible example of your issue?

If you have a function that can generate a sample for a batch size of 1, you could try using torch.func.vmap to vectorize over this function to generate data with batch size > 1. See the torch.func.vmap documentation for details.
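For reference, here’s a minimal sketch of how torch.func.vmap lifts a per-example function over a batch dimension (the function here is just illustrative; a function that samples inside vmap would additionally need the randomness="different" option, and not every sampling op supports it):

import torch
from torch.func import vmap

def normalize(p):
    # operates on a single (10,) vector
    return p / p.sum()

batched = vmap(normalize)(torch.rand(32, 10))  # applied independently to each of the 32 rows
print(batched.shape)  # torch.Size([32, 10])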

Hi @AlphaBetaGamma96,

Thanks for the suggestion. I haven’t looked at torch.func.vmap yet, so maybe that will solve my problem.

Here’s a short script demonstrating what I’m trying to do:

import torch

# our goal is to place N vehicles in L locations.
# to simplify, let's assume we have 3 vehicle-types: cars, buses, bikes
# we have 4 potential locations
# we get a probability distribution like (Locations x Vehicles):
#    cars buses bikes
# [[.08, .7, .25], <- location0
#  [.9, .01, .25],
#  [.01, .25, .25],
#  [.01, .04, .25]]

# we also get a vector of vehicle counts like [1, 3, 0]
# so our output (Locations x Vehicles), after sampling, might look like:
# [[0, 2, 0],
#  [1, 0, 0],
#  [0, 1, 0],
#  [0, 0, 0]]

# batch_size = 1 case
probs1 = torch.arange(4 * 3).reshape(4, 3).float()  # float, unnormalized probs (Multinomial normalizes along the last dim)
vehicles1 = [1, 3, 1]


def smpl(probs, vehicles):
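    # one Multinomial per vehicle-type, with total_count set to that type's count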
    cats = [
        torch.distributions.multinomial.Multinomial(total_count=n, probs=prob)
        for n, prob in zip(vehicles, torch.unbind(probs, dim=-1))
    ]
    sample_ = torch.stack([cat.sample() for cat in cats], dim=-1)
    return sample_


# can also do
def smpl2(probs, vehicles):
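    # alternative: leave total_count at its default of 1, draw n one-hot samples per type, and sum them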
    cats = [
        torch.distributions.multinomial.Multinomial(probs=prob)
        for prob in torch.unbind(probs, dim=-1)
    ]
    sample_ = torch.stack(
        [cat.sample((n,)).sum(dim=0) for n, cat in zip(vehicles, cats)], dim=-1
    )
    return sample_


# batch_size = 5 case
probs5 = torch.arange(5 * 4 * 3).reshape(5, 4, 3).float()  # again, float unnormalized probs
vehicles5 = torch.arange(5 * 3).reshape(5, 3)


# now the first approach doesn't work, because total_count can only be an int
def smpl_batch(probs, vehicles):
    cats = [
        torch.distributions.multinomial.Multinomial(total_count=n, probs=prob)
        for n, prob in zip(torch.unbind(vehicles, dim=-1), torch.unbind(probs, dim=-1))
    ]
    sample_ = torch.stack([cat.sample() for cat in cats], dim=-1)
    return sample_


# ERROR: total_count only accepts a Python int, not a tensor


# and the second approach might work if I could figure out the right slicing on something like:
def smpl2_batch(probs, vehicles):
    cats = [
        torch.distributions.multinomial.Multinomial(probs=prob)
        for prob in torch.unbind(probs, dim=-1)
    ]
    # per-type sample tensors have different leading dims, so they can't be stacked yet
    samples = [
        cat.sample((int(n.max()),))
        for n, cat in zip(torch.unbind(vehicles, dim=-1), cats)
    ]
    # do some crazy slicing/summing here on samples and vehicles
    return samples


# but I can't figure out how the slicing/summing would work to get back to a (batch x locations x vehicles) size

if __name__ == "__main__":
    print("Batch Size: 1")
    print(f"Probs: {probs1}")
    print(f"Vehicles: {vehicles1}")
    print(f"Result first approach: {smpl(probs1, vehicles1)}\n")
    print(f"Result second approach: {smpl2(probs1, vehicles1)}\n")
    print("Batch Size: 5")
    print(f"Probs: {probs5}")
    print(f"Vehicles: {vehicles5}")
    print(f"Result: {smpl2_batch(probs5, vehicles5)}")

At the end of the day, I’d really like to be able to pass a 1D vector to total_count that matches the innermost dimension of probs or logits, representing how many times to sample from that distribution.

For anyone finding this in the future, I figured out the slicing. This works for me:

def smpl2_batch(probs, vehicles):
    cats = [
        torch.distributions.multinomial.Multinomial(probs=prob)
        for prob in torch.unbind(probs, dim=-1)
    ]
    vehicles = torch.unbind(vehicles, dim=-1)
    # one (samples x batch x locations) tensor per vehicle-type, where samples
    # is the largest count of that type anywhere in the batch
    samples = [cat.sample((int(v.max()),)) for v, cat in zip(vehicles, cats)]
    return torch.stack(
        [
            torch.sum(
                # zero out draws at indices >= this batch element's count,
                # then collapse the samples dimension
                (torch.arange(s.size(0)) < v.unsqueeze(-1)).T.unsqueeze(-1) * s,
                dim=0,
            )
            for v, s in zip(vehicles, samples)
        ],
        dim=-1,
    )
I’ll admit it’s pretty ugly, but it essentially samples each vehicle-type as many times as the largest count of that type in the batch: samples = [cat.sample((int(v.max()),)) for v, cat in zip(vehicles, cats)].

Then, for each vehicle-type, it builds a mask by comparing an index ramp, torch.arange(s.size(0)), against each batch element’s count, v.unsqueeze(-1); the mask is False for every draw at an index at or beyond that count.

Then we reshape that mask so it can broadcast against the (samples x batch x locations) tensor: .T.unsqueeze(-1). Multiplying zeroes out the excess draws, and summing along dim=0 (the samples dimension) collapses it to shape (batch x locations).
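To make the mask concrete, here’s a toy version with samples = 3 and a batch of 2 (the values are just for illustration):

import torch

v = torch.tensor([1, 3])                  # counts for one vehicle-type across the batch
mask = torch.arange(3) < v.unsqueeze(-1)  # (batch x samples)
# tensor([[ True, False, False],
#         [ True,  True,  True]])
mask = mask.T.unsqueeze(-1)               # (samples x batch x 1), broadcasts over locations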

Repeat for each vehicle-type and stack the matrices along the last axis to get (batch x locations x vehicles).
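As a quick sanity check of the shapes, assuming the script from earlier in the thread has been run:

out = smpl2_batch(probs5, vehicles5)
print(out.shape)       # (batch x locations x vehicles): torch.Size([5, 4, 3])
print(out.sum(dim=1))  # per-type totals; these should match vehicles5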
