Select slices from a tensor fast

I want to select the corresponding entries of ref_tensor according to the mean of another tensor a, taken over its last dimension.
How can I do this selection quickly?

import torch
# Build the inputs on the CPU first and move them to the GPU only when one is
# available, so the script also runs on CPU-only machines.
use_cuda = torch.cuda.is_available()

a = torch.randn(100, 100, 50).mul(2)
ref_tensor = torch.randn(20)
if use_cuda:
    a = a.cuda()
    ref_tensor = ref_tensor.cuda()

# Mean over the last dimension; keepdim=True keeps a trailing axis of size 1,
# so a_mean has shape (a.size(0), a.size(1), 1).
a_mean = torch.mean(a, dim=2, keepdim=True)

# Turn every mean into a bin index for ref_tensor in one vectorized step.
# For example, let's just divide it by 0.1: the larger the value, the larger
# the index we select.  .long() truncates toward zero, exactly like int() on
# a single element.
idx = (a_mean / 0.1).long()
# Keep the index inside the valid range [0, ref_tensor.size(0) - 1].
idx = idx.clamp(min=0, max=ref_tensor.size(0) - 1)

# ref_tensor[idx] gathers one value per (i, j) position and has the same
# shape as a_mean; expanding along the last axis repeats that value across
# the whole slice.  contiguous() materializes the expanded view so that
# out_tensor owns its memory and can be written to independently.
out_tensor = ref_tensor[idx].expand_as(a).contiguous()

For anyone who runs into the same problem, here is a snippet that may help:

        # Scale the means so that each 0.1-wide interval maps to one index.
        a_mean = a_mean/0.1
        # Truncate to integer indices and flatten to 1-D.
        # NOTE(review): h and w are not defined in this snippet; presumably
        # they are a.size(0) and a.size(1) -- confirm against the caller.
        a_mean = a_mean.long().view(h*w)
        # The largest valid index is size(0) - 1; clamping to size(0) would
        # let an out-of-range index through to the lookup that follows.
        a_mean = torch.clamp(a_mean, min=0, max=ref_tensor.size(0) - 1)
        a_mean = a_mean.view(-1)
        out_tensor = ref_tensor.index_select(dim=0,