How can I efficiently implement this code? Implementing without the loop.
import torch
torch.manual_seed(0)
tensor = torch.rand(2,4,3,3)
def sort_tensor(tensor):
# average pooling
polling = torch.nn.AdaptiveAvgPool2d(1)
tensor_pooled = polling(tensor)
# sorting the averaged pooled values and getting the indexes
sorted_tensor_pooled, sorted_indx = torch.sort(tensor_pooled, 1, descending=True)
sorted_indx = sorted_indx.squeeze()
#sorting the orginal tensor based on the sroting index
for i in range(len(sorted_indx)):
if i == 0:
sorted_tensor = tensor[i, sorted_indx[i], :, :].unsqueeze(dim=0)
else:
sorted_tensor = torch.cat((sorted_tensor, tensor[i, sorted_indx[i], :, :].unsqueeze(dim=0)), dim=0)
return sorted_tensor
print('sorted', sort_tensor(tensor))
There is no pytorch function that can do what you are looking for. So you have to use a for loop. Your implementation can be made more efficient by creating a new tensor and then copying the contents to it, rather than increasing the current tensor by using cat.
def sort_tensor(tensor):
tensor_pooled = nn.AdaptiveAvgPool2d(1)(tensor).squeeze()
_, sorted_indx = torch.sort(tensor_pooled, 1, descending=True)
# Create a new tensor where the result would be stored
sorted_tensor = torch.empty(tensor.size())
for dim_0 in range(tensor_pooled.shape[0]):
for dim_1 in range(tensor_pooled.shape[1]):
sorted_tensor[dim_0, dim_1] = tensor[dim_0, sorted_indx[dim_0, dim_1]]
return sorted_tensor