Speed up torch.narrow over multiple dimensions

I want to implement a machine learning model that, among other things, should learn the coefficients of (quadratic) spline functions (see the Wikipedia article on B-splines for reference). As part of the forward pass, the spline function is applied to the input.

To compute the value of this function, I implemented the algorithm as described here:

In the initialization step I need to pick values out of a tensor of coefficients according to a tensor of indices.
Since I need a range of indices for each value, I used torch.narrow for this.
However, my runtime is dominated by this computation, so I am looking for ways to speed it up.

My current code to measure the runtime looks like this:

import torch
import time
import torch.nn.functional as F

num_weights = 31
batch_size = 24
values_size = 45 * 45
coeffs = torch.rand(num_weights)
values = torch.rand(batch_size, values_size)
interval = torch.arange(0, 1.1, 0.1)

# This is just so the snippet below works without error
index_tensor = torch.searchsorted(interval, values, side="right") + 2
pd = (2,2)
interval = interval.unsqueeze(0)
interval = F.pad(interval, pd, "replicate")
interval = interval.squeeze()

start = time.time()
coeff_matrix = torch.stack([torch.stack([torch.narrow(coeffs, 0, index_tensor[j][i] - 2, 3) for i in range(values_size)]) for j in range(batch_size)])
print("Time to build coeff_matrix", time.time() - start)

On my computer I get a runtime of ~0.5 seconds. For reference, the other operations in the evaluation amount to ~0.01 seconds.
I tried to calculate the matrix without the inner torch.stack(), but it was slower.
So my question is whether there is a faster way to compute my coeff_matrix. I thought I could use torch.gather(), but I am unsure how to do it.
Additionally, I thought I could exploit the fact that the indices of the coefficients I want form a contiguous window (index - 2, index - 1, index) for each entry (j, i), but I don't know how.
The coefficients are also parameters in the real application, so it is important to not break the gradient.
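
For what it's worth, the "contiguous window" idea can be sketched as follows: every window is a start index plus the fixed offsets 0, 1, 2, so broadcasting can build all window indices at once and plain advanced indexing can read them. This is only a minimal sketch with hypothetical small shapes and random indices instead of the searchsorted setup above; the names offsets and window_idx are introduced here for illustration.

```python
import torch

# Hypothetical small setup: 31 learnable coefficients, a 4 x 5 grid of indices.
coeffs = torch.rand(31, requires_grad=True)
index_tensor = torch.randint(2, 31, (4, 5))  # each entry is the last index of its window

# Each window is (idx - 2, idx - 1, idx), i.e. a start index plus fixed offsets.
offsets = torch.arange(3)                                   # tensor([0, 1, 2])
window_idx = (index_tensor - 2).unsqueeze(-1) + offsets    # shape (4, 5, 3)

# Advanced indexing reads all windows in one call and stays differentiable
# with respect to coeffs.
coeff_matrix = coeffs[window_idx]                           # shape (4, 5, 3)
```

Since indexing with an integer tensor is differentiable with respect to the indexed tensor, the gradient should still reach coeffs.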

My idea of using torch.gather() worked out fine. Here is the solution I used:

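# Add two leading singleton dimensions, then broadcast to (batch_size, values_size, num_weights)
bigger_coeffs = coeffs.unsqueeze(0).unsqueeze(0)
bigger_coeffs = bigger_coeffs.expand(batch_size, values_size, -1)
# Repeat each index three times along a new last dimension
bigger_index_tensor = index_tensor.unsqueeze(-1).expand(index_tensor.shape[0], index_tensor.shape[1], 3)
# Offsets so that each entry selects index - 2, index - 1, index
sub = torch.tensor([2, 1, 0])
index_tensor_res = bigger_index_tensor - sub[None, None, :]
# Pick the three coefficients per value along the last dimension
coeff_matrix = torch.gather(bigger_coeffs, 2, index_tensor_res)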
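
To convince myself that this is equivalent to the torch.narrow loop and does not break the gradient, a small check along the following lines can be used. It is a sketch under assumptions: the shapes are deliberately tiny, the indices are random rather than coming from searchsorted, and names such as reference, gathered and same_values are only illustrative.

```python
import torch

torch.manual_seed(0)
batch_size, values_size, num_weights = 2, 7, 31
coeffs = torch.rand(num_weights, requires_grad=True)   # stands in for the learnable parameters
index_tensor = torch.randint(2, num_weights, (batch_size, values_size))

# Reference result: the original double loop over torch.narrow.
reference = torch.stack([
    torch.stack([torch.narrow(coeffs, 0, int(index_tensor[j, i]) - 2, 3)
                 for i in range(values_size)])
    for j in range(batch_size)
])

# Vectorised version: expand the coefficients and gather along the last dimension.
expanded = coeffs.expand(batch_size, values_size, -1)
gather_idx = index_tensor.unsqueeze(-1) - torch.tensor([2, 1, 0])
gathered = torch.gather(expanded, 2, gather_idx)

# Both approaches should select exactly the same entries.
same_values = torch.equal(gathered, reference)

# Backpropagate through the gather; each selected coefficient receives gradient 1 per use.
gathered.sum().backward()
total_grad = coeffs.grad.sum().item()
```

If everything works, same_values is True and total_grad equals the number of gathered entries, batch_size * values_size * 3.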