Good news: I believe you can now do this in PyTorch 1.11.0, using `torch.bucketize`:
import torch
def remap_values(remapping, x, *, strict=True):
    """Replace every element of ``x`` by the value its key maps to.

    ``remapping`` is a ``(keys, values)`` pair of 1-D tensors: an element of
    ``x`` equal to ``keys[i]`` is replaced by ``values[i]``. The lookup uses
    :func:`torch.bucketize`, so it is vectorised and runs on whatever device
    the tensors live on.

    Args:
        remapping: ``(keys, values)`` tuple of 1-D tensors of equal length.
            ``keys`` need not be sorted; they are sorted here (stably, so each
            key keeps its own value).
        x: Tensor of any shape whose elements are looked up in ``keys``.
        strict: If true (the default), raise ``ValueError`` when ``x``
            contains an element that is not one of ``keys``. With
            ``strict=False`` the check is skipped (it costs an extra
            comparison and, on GPU, a host synchronisation); unmatched
            elements then receive the value of the smallest key not less than
            them, or of the largest key if they exceed every key.

    Returns:
        Tensor with the shape of ``x`` and the dtype of ``values``.

    Raises:
        ValueError: if ``strict`` and some element of ``x`` is not a key.
    """
    keys, values = remapping
    # bucketize needs ascending boundaries; sort the keys and carry the values
    # along so the key -> value pairing is preserved.
    sorted_keys, order = torch.sort(keys, stable=True)
    sorted_values = values[order]

    flat = x.ravel()
    index = torch.bucketize(flat, sorted_keys)
    # bucketize returns len(keys) for elements above the largest key; clamp so
    # the gathers below never go out of range (such elements are rejected by
    # the strict check, or deliberately tolerated when strict=False).
    index = index.clamp(max=sorted_keys.numel() - 1)

    if strict and flat.numel() != 0:
        if sorted_keys.numel() == 0 or not (sorted_keys[index] == flat).all():
            raise ValueError("x contains values that are not keys of the remapping")

    return sorted_values[index].reshape(x.shape)
# Demo: remap a random integer batch of shape (16, 224, 224, 3) with values in
# [0, 256) through a random permutation of 0..255. Tensors are created directly
# on the target device, falling back to CPU when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"
remapping = (
    torch.arange(0, 256, device=device),
    torch.randperm(256, device=device),
)
images_batch = torch.randint(0, 256, (16, 224, 224, 3), device=device)
remapped_batch = remap_values(remapping, images_batch)