# This should work: the loop-based and advanced-indexing results agree within a relative tolerance of 0.1.
import torch
# Self-check script: verifies that a single advanced-indexing expression
# produces the same per-batch sums as an explicit loop over kernels.

# Generate random tensor `binned_data` with dimensions (2048, 50, 149),
# i.e. (batch, channel, kernel).
# `torch.randint` excludes the upper bound, so values are integers in [0, 149]
# and are therefore valid indices into dim 1 (size 150) of `self_weights`.
binned_data = torch.randint(0, 150, (2048, 50, 149), dtype=torch.long)

# Generate random weights `self_weights` with dimensions (50, 150, 149),
# i.e. (channel, bin, kernel).
# Values are drawn from a standard normal distribution (mean=0, std=1).
self_weights = torch.randn(50, 150, 149, dtype=torch.float32)

# --- Loop-Based Method ---
# Initialize an output tensor `out_loop` of size (2048,)
# to accumulate the results from the loop-based computation.
out_loop = torch.zeros(binned_data.shape[0], dtype=torch.float32)

# Channel index of shape (50,). It does not depend on the kernel, so it is
# built once here instead of on every loop iteration.
channel_index = torch.arange(self_weights.shape[0])

# Loop over each kernel index (0 to 148, as `binned_data.shape[2]` is 149).
for kernel in range(binned_data.shape[2]):
    # Bin indices for this kernel; shape (2048, 50).
    selected_index = binned_data[:, :, kernel]

    # Weight slice for this kernel; shape (50, 150).
    selected_kernel = self_weights[:, :, kernel]

    # `channel_index` (50,) broadcasts against `selected_index` (2048, 50), so
    # element [b, c] of the result is `selected_kernel[c, selected_index[b, c]]`.
    # `selected_values` has shape (2048, 50).
    selected_values = selected_kernel[channel_index, selected_index]

    # Sum over the channels (dim=1) and accumulate into `out_loop`.
    out_loop += selected_values.sum(dim=1)

# --- Advanced Indexing Method ---
# Extract the batch size, number of channels, and number of kernels from `binned_data`.
batch_size, num_channels, num_kernels = binned_data.shape

# Create a range tensor of shape (1, 50, 1) to index across the channels.
# This will be used to index the first dimension of `self_weights`.
arange = torch.arange(num_channels).unsqueeze(0).unsqueeze(-1)

# Perform advanced indexing to gather the selected values.
# The three index tensors have shapes (1, 50, 1), (2048, 50, 149) and (1, 1, 149);
# they broadcast to (2048, 50, 149), so element [b, c, k] of the result is
# `self_weights[c, binned_data[b, c, k], k]`.
selected_values = self_weights[
    arange,
    binned_data,
    torch.arange(num_kernels).unsqueeze(0).unsqueeze(0),
]

# Sum the selected values across the channels (dim=1) and kernels (dim=2).
# This results in a tensor of shape (2048,) with one summed value per batch element.
out_advanced = selected_values.sum(dim=(1, 2))

# Check that the loop-based result (`out_loop`) and the advanced-indexing result
# (`out_advanced`) match. The two paths add the same float32 values in a different
# order, so small rounding differences are expected.
# `rtol=0.1` is the original relative tolerance. The explicit `atol` keeps the check
# from failing spuriously when a sum happens to land near zero, where a purely
# relative tolerance (default `atol` is 1e-8) shrinks to almost nothing.
assert torch.allclose(out_loop, out_advanced, rtol=0.1, atol=1e-3), "The outputs do not match!"
print("Outputs match!")
# Note: the comments above were generated with an LLM to make the code easier to follow.