Independent sampling along tensor dimension

How can the following be vectorized in PyTorch?

For each of the n items, I want to independently select one of its s options, each of which has d dimensions.
The naive for-loop looks like this:

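    import torch

    def independent_sample(data):
        # data has shape (n, d, s): n items, each with s options of dimension d
        n, d, s = data.shape
        # for each item p (shape (d, s)), pick one random option -> shape (d, 1)
        data = torch.stack([p[:, torch.randint(s, (1,))] for p in data])
        # (n, d, 1) -> (n, d)
        data = data.squeeze(-1)
        return data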

Answering myself: advanced indexing does it in one step and returns a tensor of shape (n, d):

    n, d, s = data.shape
    data_ = data[torch.arange(n), :, torch.randint(s, (n,))]
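A minimal, self-contained sketch of the indexing answer, assuming `data` has shape `(n, d, s)` as in the question. The function names `independent_sample_vectorized` and `independent_sample_gather` are hypothetical. The `torch.gather` variant is an alternative technique, not the one from the thread; it is included only to show the same selection expressed differently.

```python
import torch


def independent_sample_vectorized(data: torch.Tensor) -> torch.Tensor:
    """Pick one of the s options independently for each of the n items.

    data: tensor of shape (n, d, s). Returns a tensor of shape (n, d).
    """
    n, d, s = data.shape
    # One random option index per item, shape (n,)
    idx = torch.randint(s, (n,), device=data.device)
    # Advanced indices on dims 0 and 2 are separated by a slice, so the
    # broadcast index shape (n,) comes first, followed by the sliced dim d
    return data[torch.arange(n, device=data.device), :, idx]


def independent_sample_gather(data: torch.Tensor) -> torch.Tensor:
    """Same selection via torch.gather (alternative approach)."""
    n, d, s = data.shape
    idx = torch.randint(s, (n,), device=data.device)
    # Broadcast each item's index to (n, d, 1) so all d features use the same option
    idx = idx.view(n, 1, 1).expand(n, d, 1)
    # gather along the options dim gives (n, d, 1); drop the trailing dim
    return torch.gather(data, 2, idx).squeeze(-1)


if __name__ == "__main__":
    data = torch.randn(5, 3, 4)  # n=5 items, d=3 dims, s=4 options
    out = independent_sample_vectorized(data)
    print(out.shape)  # torch.Size([5, 3])
```

Both functions draw the indices with a single `torch.randint` call, so seeding the generator identically before each call makes them select the same options. The indexing version avoids building the expanded index tensor that `gather` needs.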