I'm trying to replace each value of a tensor with its rank in increasing order along a specific dimension, as in the following example for dimension 1 of a 3D tensor:
```
a = tensor([[[14., 99., 77.],
             [61., 22., 56.],
             [18.,  8., 45.]],
            [[94., 44.,  7.],
             [ 4., 80., 87.],
             [83., 32., 85.]]])
```
Expected result (sorting along dim = 1):
```
tensor([[[0., 2., 2.],
         [2., 1., 1.],
         [1., 0., 0.]],
        [[2., 1., 0.],
         [0., 2., 2.],
         [1., 0., 1.]]])
```
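A compact way to express this ranking, offered as a sketch that assumes there are no ties along the sorted dimension, is to call `torch.argsort` twice. The first `argsort` lists, for each rank, the position holding that value; argsorting those indices inverts the permutation, which gives each position its rank. Applied to the example above:

```python
import torch

a = torch.tensor([[[14., 99., 77.],
                   [61., 22., 56.],
                   [18.,  8., 45.]],
                  [[94., 44.,  7.],
                   [ 4., 80., 87.],
                   [83., 32., 85.]]])

dim = 1
# argsort(a): for each rank r, the index along `dim` of the r-th smallest value.
# argsort of that permutation inverts it: for each index, its rank.
ranks = torch.argsort(torch.argsort(a, dim=dim), dim=dim).to(a.dtype)
print(ranks)
```

This works for any number of dimensions and any `dim`, at the cost of two sorts instead of one.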
The following code produces the expected result, but only for 3D tensors and only along dimension 1:
```python
import torch

shape_to_use = [2, 3, 3]
dim = 1
tensor = torch.randint(0, 100, shape_to_use, dtype=torch.float)
all_values = torch.arange(0, tensor.shape[dim], dtype=torch.float)
sorted_, indices = torch.sort(tensor, dim)
# For each position (i, j) in the other two dimensions, write rank r
# at the index holding the r-th smallest value along dim 1.
for i in range(tensor.shape[0]):
    for j in range(tensor.shape[2]):
        tensor[i, indices[i, :, j], j] = all_values
```
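As a sketch, the nested loop can be replaced by `Tensor.scatter_`, which for `dim=1` writes `src[i][r][j]` into `self[i][indices[i][r][j]][j]`; that is exactly what the loop does one slice at a time. The only extra step is reshaping `all_values` so it lies along `dim` and broadcasting it to the shape of `indices`. The names `loop_result`, `scatter_result` and `view_shape` are illustrative, and the loop is run on a copy so both results can be compared:

```python
import torch

torch.manual_seed(0)
shape_to_use = [2, 3, 3]
dim = 1
tensor = torch.randint(0, 100, shape_to_use, dtype=torch.float)
all_values = torch.arange(0, tensor.shape[dim], dtype=torch.float)
sorted_, indices = torch.sort(tensor, dim)

# Loop version (3D, dim=1 only), applied to a copy of the input.
loop_result = tensor.clone()
for i in range(tensor.shape[0]):
    for j in range(tensor.shape[2]):
        loop_result[i, indices[i, :, j], j] = all_values

# Vectorized version: view all_values as size 1 everywhere except `dim`,
# broadcast it to indices' shape, then scatter rank r to indices[..., r, ...].
view_shape = [1] * tensor.dim()
view_shape[dim] = -1
src = all_values.view(view_shape).expand_as(indices)
scatter_result = torch.empty_like(tensor).scatter_(dim, indices, src)

print(torch.equal(loop_result, scatter_result))
```

Nothing in the scatter version refers to a specific number of dimensions, so the same lines should work for a 4D tensor or another `dim`. Because both versions reuse the same `indices`, they agree even when the input contains repeated values.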
The problem is that the loop structure depends on both the number of dimensions and the selected dimension. For a 4D tensor (still sorting along dim 1), the assignment needs another nested loop:
```python
for i in range(tensor.shape[0]):
    for j in range(tensor.shape[2]):
        for k in range(tensor.shape[3]):
            tensor[i, indices[i, :, j, k], j, k] = all_values
```
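If an explicit loop is acceptable, one way to avoid writing a new loop nest per shape (a sketch; `rank_along_dim_loop` is a hypothetical helper name) is to move the sorted dimension to the end with `torch.movedim`, flatten every other dimension into rows, and loop over those rows once:

```python
import torch

def rank_along_dim_loop(tensor, dim):
    # Move the ranking dimension last and flatten the rest into rows,
    # so one loop handles any number of dimensions.
    moved = tensor.movedim(dim, -1)
    rows = moved.reshape(-1, moved.shape[-1])
    out = torch.empty_like(rows)
    all_values = torch.arange(rows.shape[-1], dtype=tensor.dtype, device=tensor.device)
    _, indices = torch.sort(rows, dim=-1)
    for r in range(rows.shape[0]):
        out[r, indices[r]] = all_values
    # Undo the flattening and the move to restore the original layout.
    return out.reshape(moved.shape).movedim(-1, dim)

# Distinct values (no ties), so the result does not depend on tie ordering.
t = torch.randperm(2 * 3 * 4 * 5).float().reshape(2, 3, 4, 5)
dim = 1
all_values = torch.arange(0, t.shape[dim], dtype=torch.float)
_, indices = torch.sort(t, dim)
loop4d = t.clone()
for i in range(t.shape[0]):
    for j in range(t.shape[2]):
        for k in range(t.shape[3]):
            loop4d[i, indices[i, :, j, k], j, k] = all_values

print(torch.equal(rank_along_dim_loop(t, dim), loop4d))
```

This still loops in Python over every row, so it will be slower than a vectorized approach on large tensors.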
Is there a better way to do this, ideally without loops, or at least without writing specific code for every possible number of dimensions and sort dimension?
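One caveat worth checking: `torch.randint(0, 100, ...)` can produce repeated values, and with the default `torch.sort` the relative order of equal elements is not guaranteed, so tied values get distinct but unspecified ranks. A sketch of a general helper (the name `rank_along_dim` is hypothetical) that passes `stable=True` to `torch.sort`, assuming a PyTorch version that supports that keyword, so ties are ranked by their position along `dim`:

```python
import torch

def rank_along_dim(tensor, dim):
    # stable=True keeps equal values in their original order, so ties
    # receive increasing ranks by position rather than an unspecified order.
    _, indices = torch.sort(tensor, dim=dim, stable=True)
    n = tensor.shape[dim]
    view_shape = [1] * tensor.dim()
    view_shape[dim] = n
    src = torch.arange(n, dtype=tensor.dtype, device=tensor.device).view(view_shape)
    # Write rank r into the position that holds the r-th smallest value.
    return torch.empty_like(tensor).scatter_(dim, indices, src.expand_as(indices))

t = torch.tensor([[5., 1.],
                  [5., 3.],
                  [1., 3.]])
print(rank_along_dim(t, dim=0))
```

In the first column the two 5s are ranked 1 and 2 in row order; if ties should instead share a rank, a different ranking scheme would be needed.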