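I want to fill each slice of an output tensor with the corresponding element selected from `ref_tensor`, where the element is chosen according to the mean of another tensor, `a`, over its last dimension.
I wonder how to do this selection fast, without a Python loop.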
import torch


def select_by_mean(a, ref_tensor, step=0.1):
    """Fill each last-dim slice of ``a``'s shape with an entry of ``ref_tensor``.

    Vectorized replacement for the per-(i, j) Python loop. For every position
    along the leading dimensions:

    1. The mean of ``a`` over its last dimension is divided by ``step``.
    2. The result is truncated toward zero (like ``int()``).
    3. It is clamped to ``[0, len(ref_tensor) - 1]``.
    4. The selected entry of ``ref_tensor`` is repeated across that whole slice.

    Args:
        a: Tensor with at least one dimension; the mean is taken over its
            last dimension (``dim=2`` for the original 3-D input).
        ref_tensor: Non-empty 1-D lookup table.
        step: Divisor mapping a mean value to an index (for positive
            ``step``, a larger mean selects a larger index).

    Returns:
        A new contiguous tensor with ``a``'s shape and ``ref_tensor``'s
        dtype and device.

    Raises:
        ValueError: If ``ref_tensor`` is empty (there is nothing to select).
    """
    if ref_tensor.numel() == 0:
        raise ValueError("ref_tensor must not be empty")
    a_mean = torch.mean(a, dim=-1, keepdim=True)
    # .long() truncates toward zero, matching int() in the loop version.
    idx = (a_mean / step).long()
    # Same effect as the two conditional clamps in the loop; moved to the
    # lookup table's device so indexing works even if a lives elsewhere.
    idx = torch.clamp(idx, 0, ref_tensor.size(0) - 1).to(ref_tensor.device)
    # Gather one value per slice, broadcast it along the last dim, and
    # materialize a fresh tensor (expand alone would return a stride-0 view).
    return ref_tensor[idx].expand(a.shape).contiguous()


# Demo mirroring the original script. It falls back to CPU when CUDA is
# unavailable so the file also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
a = torch.randn(100, 100, 50, device=device).mul(2)
a_mean = torch.mean(a, dim=2, keepdim=True)
ref_tensor = torch.randn(20, device=device)
out_tensor = select_by_mean(a, ref_tensor)