I want to implement a machine learning model that, among other things, should learn the coefficients of (quadratic) spline functions (see B-spline on Wikipedia for reference). As part of the forward operation, the spline function is applied to the input.
To compute the value of this function, I implemented the algorithm as described here:
In the initialization step I need to get values out of a tensor of coefficients according to a tensor of indices.
Since I need a range of indices, I used torch.narrow for this problem.
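As a small illustration of what torch.narrow returns for a single index (the tensor contents and the index value below are made up for the example, not taken from the real model):

```python
import torch

coeffs = torch.arange(10.0)   # stand-in coefficients 0., 1., ..., 9.
idx = 5                       # hypothetical interval index for one input value

# narrow(input, dim, start, length): a view of 3 consecutive coefficients
window = torch.narrow(coeffs, 0, idx - 2, 3)
print(window)  # tensor([3., 4., 5.])
```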
However, my runtime is dominated by this computation, so I am searching for ways to speed it up.
My current code to measure the runtime looks like this:
    import torch
    import time
    import torch.nn.functional as F

    num_weights = 31
    batch_size = 24
    values_size = 45 * 45

    coeffs = torch.rand(num_weights)
    values = torch.rand(batch_size, values_size)
    interval = torch.arange(0, 1.1, 0.1)

    # This is just so the snippet below works without error
    index_tensor = torch.searchsorted(interval, values, right=True, side="right") + 2*torch.ones(values.size(), dtype=torch.int64)

    pd = (2, 2)
    interval = interval.unsqueeze(0)
    interval = F.pad(interval, pd, "replicate")
    interval = interval.squeeze()

    start = time.time()
    coeff_matrix = torch.stack([
        torch.stack([torch.narrow(coeffs, 0, index_tensor[j][i] - 2, 3) for i in range(values_size)])
        for j in range(batch_size)
    ])
    print("Time to build coeff_matrix", time.time() - start)
On my computer I get a runtime of ~0.5 seconds. Just for reference, the other operations in the evaluation amount to ~0.01 seconds.
I tried to calculate the matrix first without torch.stack() in the middle, but it was slower.
So my question is whether there is a faster method to compute my coeff_matrix. I thought I could use torch.gather(), but I am unsure how to do it.
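One possible way to express the lookup with torch.gather is sketched below. It has not been benchmarked. It assumes that every index k should select the coefficients at positions k-2, k-1 and k, and it uses small sizes and randomly drawn indices (hypothetical, but in the same range as in the snippet above) so that the result can be compared against the loop:

```python
import torch

torch.manual_seed(0)
num_weights, batch_size, values_size = 31, 4, 50   # small sizes for the check

coeffs = torch.rand(num_weights)
# hypothetical indices in the range 3..12 (upper bound of randint is exclusive)
index_tensor = torch.randint(3, 13, (batch_size, values_size))

# For every index k, collect the positions k-2, k-1, k and flatten them
offsets = torch.arange(3)
flat_idx = (index_tensor.unsqueeze(-1) - 2 + offsets).reshape(-1)

# gather on the 1-D coefficient tensor, then restore the (batch, values, 3) shape
coeff_matrix = torch.gather(coeffs, 0, flat_idx).reshape(batch_size, values_size, 3)

# reference: the nested loop with torch.narrow
reference = torch.stack([
    torch.stack([torch.narrow(coeffs, 0, int(index_tensor[j, i]) - 2, 3)
                 for i in range(values_size)])
    for j in range(batch_size)
])
print(torch.equal(coeff_matrix, reference))
```

The index arithmetic is done once for the whole batch, so no Python-level loop over i and j is needed.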
Additionally, I thought I could exploit the fact that the coefficients I want to take depend on i and j linearly, but I don’t know how.
The coefficients are also parameters in the real application, so it is important to not break the gradient.
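To make the gradient requirement concrete, here is a minimal sketch that uses plain integer-array indexing and checks that gradients reach the coefficients. It assumes coeffs is a leaf tensor with requires_grad=True standing in for the real parameter, and the indices are again hypothetical random values in the range 3..12:

```python
import torch

torch.manual_seed(0)
num_weights, batch_size, values_size = 31, 24, 45 * 45

# stands in for the learnable parameter of the real model
coeffs = torch.rand(num_weights, requires_grad=True)
index_tensor = torch.randint(3, 13, (batch_size, values_size))  # hypothetical indices

# Broadcast (batch, values, 1) + (3,) -> (batch, values, 3) integer indices
window_idx = index_tensor.unsqueeze(-1) - 2 + torch.arange(3)

# Indexing with an integer tensor is differentiable with respect to coeffs
coeff_matrix = coeffs[window_idx]

# Backpropagate a dummy loss; each coefficient's gradient counts how often it was selected
coeff_matrix.sum().backward()
print(coeff_matrix.shape)       # torch.Size([24, 2025, 3])
print(coeffs.grad is not None)  # True
```

With the dummy loss used here, the entries of coeffs.grad sum to the total number of selected elements, which is a quick sanity check that no selection was dropped from the graph.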