The following piece of code extracts a batch of square submatrices sub_S from a batch of square matrices S. It is currently a bottleneck in a larger system. Does PyTorch provide any functionality to get rid of the for loops? So far I haven't found anything that improves it.
Here is a minimal working example:
import torch

# Parameters
batch_size = 2
size = 24
n_idx = 6

# Mock data
S = torch.arange(end=batch_size * size * size).reshape(batch_size, size, size)

# Mock indices: one (n_idx**2, 2) tensor of (row, col) pairs per batch element
indices = list()
for _ in range(batch_size):
    indices.append(torch.randint(high=size, size=(n_idx**2, 2)))

# Extract each submatrix with a Python-level loop over the batch
sub_S = list()
for b in range(batch_size):
    sub_S.append(S[b][indices[b][:, 0], indices[b][:, 1]].reshape(n_idx, n_idx))
sub_S = torch.stack(sub_S)
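One idea I've been toying with (not sure if this is the intended way): if the per-batch index tensors are stacked into a single (batch_size, n_idx**2, 2) tensor, batched advanced indexing with a broadcast batch-index column might replace both loops. A sketch of what I mean, assuming the stacked-indices layout:

```python
import torch

batch_size = 2
size = 24
n_idx = 6

S = torch.arange(end=batch_size * size * size).reshape(batch_size, size, size)

# Stacked indices: shape (batch_size, n_idx**2, 2) instead of a Python list
indices = torch.randint(high=size, size=(batch_size, n_idx**2, 2))

# Column of batch indices, shape (batch_size, 1); it broadcasts against
# the (batch_size, n_idx**2) row/col index tensors
batch_idx = torch.arange(batch_size).unsqueeze(1)

# One fused gather over the whole batch, then reshape to square submatrices
sub_S = S[batch_idx, indices[..., 0], indices[..., 1]].reshape(batch_size, n_idx, n_idx)
```

As far as I can tell this produces the same result as the loop version, but I'd appreciate confirmation that the broadcasting here is well-defined and actually faster.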