How can the following be vectorized in PyTorch?
For each of the n items, I want to select one of the s options, each of which has d dimensions. The input has shape (n, d, s) and the result should have shape (n, d).
The naive for-loop looks like this:
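
    import torch

    def independent_sample(data):
        # data has shape (n, d, s)
        n, d, s = data.shape
        # For each item p of shape (d, s), pick one random option column -> (d, 1)
        data = torch.stack([p[:, torch.randint(s, (1,))] for p in data])
        # Drop the trailing singleton dimension: (n, d, 1) -> (n, d)
        data = data.squeeze(-1)
        return data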
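
For comparison, one possible vectorized form is sketched below. It assumes that a single batched `torch.randint` call is an acceptable replacement for the n separate calls in the loop, and it uses advanced indexing to pair each item with its own sampled option. The function name `independent_sample_vectorized` is made up for this illustration.

```python
import torch

def independent_sample_vectorized(data: torch.Tensor) -> torch.Tensor:
    # data has shape (n, d, s)
    n, d, s = data.shape
    # One option index per item, drawn uniformly from [0, s)
    idx = torch.randint(s, (n,), device=data.device)
    # Row index 0..n-1 so that item i is paired with idx[i]
    rows = torch.arange(n, device=data.device)
    # Advanced indexing on dims 0 and 2, full slice on dim 1 -> shape (n, d)
    return data[rows, :, idx]

# Example: 5 items, 3 dimensions, 4 options each
sample = independent_sample_vectorized(torch.randn(5, 3, 4))
print(sample.shape)
```

Because the indices are drawn in one call instead of n, the result need not match the loop version sample for sample under the same seed, although each item still gets an independent, uniformly chosen option.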