Help Optimizing a PyTorch Loop with Advanced Indexing

Hey everyone,

I’m working on optimizing a PyTorch operation by eliminating a for loop in favor of advanced indexing. My current implementation iterates over the last dimension of my binned_data tensor and uses the values in each slice as indices to select the corresponding weights from self.weights. Here’s a quick overview of my current setup:

Tensor Shapes:

  • binned_data: torch.Size([2048, 50, 149])
  • self.weights: torch.Size([50, 150, 149])


Current Implementation:

out = torch.zeros(size=(binned_data.shape[0],), dtype=torch.float32)
arange = torch.arange(0, self.weights.shape[0])
for kernel in range(binned_data.shape[2]):
    # (2048, 50) bin indices for this kernel slice
    selected_index = binned_data[:, :, kernel]
    # (50, 150) weights for this kernel slice
    selected_kernel = self.weights[:, :, kernel]
    # arange (50,) broadcasts against selected_index (2048, 50) -> (2048, 50)
    selected_values = selected_kernel[arange, selected_index]
    out += selected_values.sum(dim=1)

Objective:
I want to replace the for loop with a single advanced-indexing operation that produces the same result, so the whole computation happens in one vectorized step without sacrificing performance.
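
To make it concrete, here’s a tiny toy example of the kind of indexing I mean (shapes and values are made up just for illustration); the integer index tensors broadcast against each other like ordinary tensors:

import torch

# Toy illustration (made-up shapes): integer index tensors broadcast against
# each other, so one indexing expression can select many elements at once.
w = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)  # (channel, bin, kernel)
rows = torch.arange(2).view(2, 1)            # (2, 1) -> indexes dim 0
bins = torch.tensor([[0, 2, 1, 0],
                     [2, 2, 0, 1]])          # (2, 4) -> indexes dim 1
cols = torch.arange(4).view(1, 4)            # (1, 4) -> indexes dim 2

# The three index tensors broadcast to shape (2, 4), and
# w[rows, bins, cols][i, j] == w[rows[i, 0], bins[i, j], cols[0, j]].
print(w[rows, bins, cols].shape)  # torch.Size([2, 4])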

If anyone has experience with this type of optimization or can suggest a better way to implement this using PyTorch’s advanced indexing, I would greatly appreciate your input!

Thanks in advance!

This should work (the loop result and the vectorized result match within a relative tolerance of 0.1).

import torch

# Generate random tensor `binned_data` with dimensions (2048, 50, 149).
# Values are integers between 0 and 149 (inclusive).
binned_data = torch.randint(0, 150, (2048, 50, 149), dtype=torch.long)

# Generate random weights `self_weights` with dimensions (50, 150, 149).
# Values are drawn from a standard normal distribution (mean=0, std=1).
self_weights = torch.randn(50, 150, 149, dtype=torch.float32)

# Initialize an output tensor `out_loop` of size (2048,)
# to accumulate the results from the loop-based computation.
out_loop = torch.zeros(binned_data.shape[0], dtype=torch.float32)

# Loop over each kernel index (range from 0 to 148, as `binned_data.shape[2]` is 149).
for kernel in range(binned_data.shape[2]):
    # For each kernel index, extract the corresponding indices from `binned_data`.
    # `selected_index` has shape (2048, 50) as it selects a 2D slice from `binned_data`.
    selected_index = binned_data[:, :, kernel]
    
    # Extract the kernel slice from `self_weights` corresponding to the current kernel index.
    # `selected_kernel` has shape (50, 150) as it selects a 2D slice from `self_weights`.
    selected_kernel = self_weights[:, :, kernel]
    
    # Select the values from `selected_kernel` using the indices in `selected_index`.
    # The `torch.arange(selected_kernel.shape[0])` creates an index for each channel (size 50),
    # and `selected_index` determines the column in `selected_kernel` to select for each batch and channel.
    # `selected_values` has shape (2048, 50) after gathering the selected weights.
    selected_values = selected_kernel[torch.arange(selected_kernel.shape[0]), selected_index]
    
    # Sum the selected values across the channels (dim=1), resulting in a tensor of shape (2048,).
    # Accumulate the result into `out_loop`.
    out_loop += selected_values.sum(dim=1)

# --- Advanced Indexing Method ---

# Extract the batch size, number of channels, and number of kernels from `binned_data`.
batch_size, num_channels, num_kernels = binned_data.shape

# Create a range tensor of size (1, 50, 1) to index across the channels.
# This will be used to index the first dimension of `self_weights`.
arange = torch.arange(num_channels).unsqueeze(0).unsqueeze(-1)

# Perform advanced indexing to gather the selected values.
# `arange` indexes the channels in `self_weights`.
# `binned_data` provides the indices for selecting from the second dimension of `self_weights`.
# `torch.arange(num_kernels).unsqueeze(0).unsqueeze(0)` indexes the kernels in `self_weights`.
# `selected_values` will have shape (2048, 50, 149), corresponding to the selected values for each batch, channel, and kernel.
selected_values = self_weights[
    arange,
    binned_data,
    torch.arange(num_kernels).unsqueeze(0).unsqueeze(0)
]

# Sum the selected values across the channels (dim=1) and kernels (dim=2).
# This results in a tensor of shape (2048,) corresponding to the summed results for each batch.
out_advanced = selected_values.sum(dim=(1, 2))

# Check if the results from the loop-based approach (`out_loop`) and the advanced indexing approach (`out_advanced`) match.
# Allow a relative tolerance of 0.1 to account for minor differences due to floating-point operations.
assert torch.allclose(out_loop, out_advanced, rtol=0.1), "The outputs do not match!"
print("Outputs match!")

Note: Used LLM for comments to make life easier for everyone haha
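
If you want to sanity-check the speedup, a rough timing comparison could look like the sketch below. It assumes the CPU tensors defined above; on GPU you would also need torch.cuda.synchronize() around the timers to get meaningful numbers.

import time

def run_loop():
    out = torch.zeros(binned_data.shape[0], dtype=torch.float32)
    arange = torch.arange(self_weights.shape[0])
    for kernel in range(binned_data.shape[2]):
        selected = self_weights[:, :, kernel][arange, binned_data[:, :, kernel]]
        out += selected.sum(dim=1)
    return out

def run_vectorized():
    channel_idx = torch.arange(num_channels).unsqueeze(0).unsqueeze(-1)  # (1, 50, 1)
    kernel_idx = torch.arange(num_kernels).unsqueeze(0).unsqueeze(0)     # (1, 1, 149)
    return self_weights[channel_idx, binned_data, kernel_idx].sum(dim=(1, 2))

# Average over a few repetitions to smooth out noise.
for name, fn in [("loop", run_loop), ("advanced indexing", run_vectorized)]:
    start = time.perf_counter()
    for _ in range(10):
        fn()
    print(f"{name}: {(time.perf_counter() - start) / 10:.4f} s per call")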