An efficient way to slice a torch tensor

How can I efficiently implement this code? Implementing without the loop.


import torch
torch.manual_seed(0)
tensor = torch.rand(2,4,3,3)
def sort_tensor(tensor):
        # average pooling
        polling = torch.nn.AdaptiveAvgPool2d(1)
        tensor_pooled = polling(tensor)
        # sorting the averaged pooled values and getting the indexes 
        sorted_tensor_pooled, sorted_indx = torch.sort(tensor_pooled, 1, descending=True)
        sorted_indx = sorted_indx.squeeze()
        
        #sorting the orginal tensor based on the sroting index
        for i in range(len(sorted_indx)):
            if i == 0:
                sorted_tensor = tensor[i, sorted_indx[i], :, :].unsqueeze(dim=0)
            else:
                sorted_tensor = torch.cat((sorted_tensor, tensor[i, sorted_indx[i], :, :].unsqueeze(dim=0)), dim=0)

        return sorted_tensor
print('sorted', sort_tensor(tensor))

There is no pytorch function that can do what you are looking for. So you have to use a for loop. Your implementation can be made more efficient by creating a new tensor and then copying the contents to it, rather than increasing the current tensor by using cat.

def sort_tensor(tensor):
    tensor_pooled = nn.AdaptiveAvgPool2d(1)(tensor).squeeze()
    
    _, sorted_indx = torch.sort(tensor_pooled, 1, descending=True)
    
    # Create a new tensor where the result would be stored
    sorted_tensor = torch.empty(tensor.size())
    for dim_0 in range(tensor_pooled.shape[0]):
        for dim_1 in range(tensor_pooled.shape[1]):
            sorted_tensor[dim_0, dim_1] = tensor[dim_0, sorted_indx[dim_0, dim_1]]
    
    return sorted_tensor

Hi @Abe , I have converted the loop into a vectorized format.
I have verified with multiple inputs, and it works much faster.
Here’s the code.

def sort_tensorV2(tensor):
        # average pooling
        polling = torch.nn.AdaptiveAvgPool2d(1)
        tensor_pooled = polling(tensor)
        print("Polling shape ", tensor_pooled.shape)
        # sorting the averaged pooled values and getting the indexes 
        # print("Tensor-pooled ", tensor_pooled)
        sorted_tensor_pooled, sorted_indx = torch.sort(tensor_pooled, 1, descending=True)
        sorted_indx = sorted_indx.squeeze()

        a, b, c, d = tensor.shape
        mults = np.arange(0,a*b,b).reshape(-1,1)
        temp_indices = (mults+sorted_indx).flatten()
        temp = tensor.reshape(a*b,c,d)
        final = temp[temp_indices,:,:]
        sorted_tensor = final.reshape(a,b,c,d)
        return sorted_tensor

Hope it helps.!