How can the following be vectorized in PyTorch?

For each `n`, I want to select one of the `s` options, each of which has `d` dimensions. The naive for-loop looks like this:
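```python
import torch

def independent_sample(data):
    n, d, s = data.shape
    # For each of the n slices p (shape (d, s)), pick one random column -> (d, 1)
    data = torch.stack([p[:, torch.randint(s, (1,))] for p in data])
    # (n, d, 1) -> (n, d)
    data = data.squeeze(-1)
    return data
```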
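One possible vectorization, sketched under the assumption that the goal is exactly what the loop does (one option drawn uniformly and independently for each `n`), is to draw all `n` indices at once and use advanced indexing. The function names `vectorized_sample` and `vectorized_sample_gather` and the optional `idx` argument are hypothetical, added here only so the selection can be checked against known indices; the `torch.gather` variant is shown as an equivalent alternative.

```python
import torch

def vectorized_sample(data, idx=None):
    """Pick one of the s options for every n without a Python loop.

    data: tensor of shape (n, d, s)
    idx:  optional (hypothetical) LongTensor of shape (n,) with values in [0, s);
          drawn uniformly at random when omitted, like the loop version.
    Returns a tensor of shape (n, d).
    """
    n, d, s = data.shape
    if idx is None:
        # One independent random option index per n.
        idx = torch.randint(s, (n,), device=data.device)
    # Pair row i with option idx[i]; the slice keeps all d dimensions.
    # The two index tensors are separated by a slice, so the indexed
    # dimension (n) is placed first and the result has shape (n, d).
    return data[torch.arange(n, device=data.device), :, idx]

def vectorized_sample_gather(data, idx):
    """Same selection via torch.gather, for a given idx of shape (n,)."""
    n, d, s = data.shape
    # Broadcast idx to (n, d, 1) so all d dimensions read the same option.
    index = idx.view(n, 1, 1).expand(n, d, 1)
    return torch.gather(data, 2, index).squeeze(-1)
```

With fixed indices, e.g. `idx = torch.tensor([1, 3])`, row `i` of the result equals `data[i, :, idx[i]]`, which is what the loop produces for the same random draws. Both versions avoid the per-row `torch.randint` call and the `torch.stack` over a Python list.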