I want to implement a machine-learning model that, among other things, should learn coefficients for (quadratic) spline functions (see B-spline - Wikipedia for reference). As part of the forward operation the spline function is applied to the input.

To compute the value of this function I implemented the algorithm as described here:

In the initialization step I need to get values out of a tensor of coefficients according to a tensor of indices. Since I need a range of indices, I used `torch.narrow` for this problem.
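For context, `torch.narrow` returns a view of a fixed-length slice along one dimension; a minimal toy example (values unrelated to the model):

```
import torch

coeffs = torch.arange(10.0)
# Take 3 elements along dim 0, starting at index 4 -- a view, no copy
window = torch.narrow(coeffs, 0, 4, 3)
print(window)  # tensor([4., 5., 6.])
```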

However, my runtime is dominated by this computation, so I am searching for ways to speed it up.

My current code to measure the runtime looks like this:

```
import torch
import time
import torch.nn.functional as F
num_weights = 31
batch_size = 24
values_size = 45 * 45
coeffs = torch.rand(num_weights)
values = torch.rand(batch_size, values_size)
interval = torch.arange(0, 1.1, 0.1)
#This is just so the snippet below works without error
index_tensor = torch.searchsorted(interval, values, side="right") + 2
pd = (2,2)
interval = interval.unsqueeze(0)
interval = F.pad(interval, pd, "replicate")
interval = interval.squeeze()
start = time.time()
coeff_matrix = torch.stack([torch.stack([torch.narrow(coeffs, 0, index_tensor[j][i] - 2, 3) for i in range(values_size)]) for j in range(batch_size)])
print("Time to build coeff_matrix", time.time() - start)
```

On my computer I get a runtime of ~0.5 seconds. Just for reference, the other operations in the evaluation amount to ~0.01 seconds.

I tried to calculate the matrix without the inner `torch.stack()` first, but it was slower.

So my question is: is there a faster method to compute my `coeff_matrix`?

I thought I could use `torch.gather()`, but I am unsure about how to do it.
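To make the idea concrete, here is a toy example of the `torch.gather` semantics (values unrelated to the model, just to illustrate the call):

```
import torch

src = torch.tensor([[10., 20., 30.],
                    [40., 50., 60.]])
idx = torch.tensor([[2, 0],
                    [1, 1]])
# Along dim 1: out[i][k] = src[i][idx[i][k]]
out = torch.gather(src, 1, idx)
print(out)  # tensor([[30., 10.], [50., 50.]])
```

Since `torch.gather` requires the index tensor to have the same number of dimensions as the source, and my `coeffs` is one-dimensional while `index_tensor` is two-dimensional, the coefficients would presumably have to be expanded first, which is part of what I am unsure about.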

Additionally, I thought I could exploit the fact that the indices of the coefficients I want to take depend linearly on i and j, but I don't know how.
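To illustrate the kind of structure I mean: every window consists of `index_tensor[j][i] - 2` plus the offsets 0, 1, 2, so all required indices can be built in one broadcasted expression and used for advanced indexing (a sketch of the idea with small toy sizes, not benchmarked):

```
import torch

num_weights, batch_size, values_size = 31, 4, 6
coeffs = torch.rand(num_weights, requires_grad=True)
index_tensor = torch.randint(2, num_weights, (batch_size, values_size))

# all_idx[j, i, k] = index_tensor[j, i] - 2 + k, for k in {0, 1, 2}
all_idx = index_tensor.unsqueeze(-1) - 2 + torch.arange(3)

# Advanced indexing into the 1-D coefficient tensor
coeff_matrix = coeffs[all_idx]  # shape (batch_size, values_size, 3)
```

Advanced indexing is differentiable with respect to `coeffs` (gradients for repeated indices are accumulated), so this would at least not break the computation graph.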

The coefficients are also trainable parameters in the real application, so it is important not to break the gradient.