Extract batch of square submatrices from batch of square matrices

The following code extracts a batch of square submatrices `sub_S` from a batch of square matrices `S`. It is currently a bottleneck in a larger system. Does PyTorch provide any functionality to get rid of the for loops? So far I haven't found anything that could improve it.

Here is a minimal working example:

``````
import torch

# Parameters
batch_size = 2
size = 24
n_idx = 6

# Mock data
S = torch.arange(end=batch_size*size*size).reshape(batch_size, size, size)

# Mock indices: one (n_idx**2, 2) tensor of (row, col) pairs per batch element
indices = list()
for _ in range(batch_size):
    indices.append(torch.randint(high=size, size=(n_idx**2, 2)))

# Gather n_idx**2 entries per batch element and reshape them to (n_idx, n_idx)
sub_S = list()
for b in range(batch_size):
    sub_S.append(S[b][indices[b][:, 0], indices[b][:, 1]].reshape(n_idx, n_idx))

sub_S = torch.stack(sub_S)
``````

Hi Samue1,

To eliminate the second for loop in your code, you can index the batch dimension with a broadcast index as well. An example is shown below:

``````
import torch

batch_size = 100
size = 24
n_idx = 4

# Mock data
S = torch.arange(end=batch_size*size*size).reshape(batch_size, size, size)

# Mock indices
indices = list()
for _ in range(batch_size):
    indices.append(torch.randint(high=size, size=(n_idx**2, 2)))

def forlooping(S, indices):
    # Original approach: one indexing operation per batch element
    sub_S = list()
    for b in range(batch_size):
        sub_S.append(S[b][indices[b][:, 0], indices[b][:, 1]].reshape(n_idx, n_idx))
    sub_S = torch.stack(sub_S)
    return sub_S

def indexing(S, indices):
    # Stack to (batch_size, n_idx**2, 2), then view as (batch_size, n_idx, n_idx, 2)
    indices = torch.stack(indices, dim=0).view(-1, n_idx, n_idx, 2)
    # Batch index of shape (batch_size, 1, 1) broadcasts against the row/col indices
    bidx = torch.arange(batch_size).view(-1,1,1)
    sub_S = S[bidx, indices[...,0], indices[...,1]]
    return sub_S
``````

The speedup is noticeable with large batch sizes, for example with `batch_size = 100` as above:

``````
In [2]: %timeit indexing(S, indices)
155 µs ± 1.98 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [3]: %timeit forlooping(S, indices)
1.9 ms ± 14.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
``````
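The broadcasting used in `indexing` can be checked on a small case. The following is only a sketch with illustrative sizes (the values `3, 5, 2` and the names `idx`, `rows`, `cols`, and `expected` are assumptions, not part of the code above). It spells out the shape of each index tensor and compares the result against the loop version:

```python
import torch

batch_size, size, n_idx = 3, 5, 2  # small illustrative values (assumed)

S = torch.arange(batch_size * size * size).reshape(batch_size, size, size)
# One (n_idx**2, 2) tensor of (row, col) pairs per batch element
indices = [torch.randint(high=size, size=(n_idx**2, 2)) for _ in range(batch_size)]

# (batch_size, n_idx**2, 2) -> (batch_size, n_idx, n_idx, 2)
idx = torch.stack(indices, dim=0).view(-1, n_idx, n_idx, 2)
rows = idx[..., 0]  # row indices, shape (batch_size, n_idx, n_idx)
cols = idx[..., 1]  # column indices, shape (batch_size, n_idx, n_idx)
bidx = torch.arange(batch_size).view(-1, 1, 1)  # shape (batch_size, 1, 1)

# The three index tensors broadcast to (batch_size, n_idx, n_idx), so
# sub_S[b, i, j] == S[b, rows[b, i, j], cols[b, i, j]]
sub_S = S[bidx, rows, cols]

# Reference result from the per-batch loop
expected = torch.stack([
    S[b][indices[b][:, 0], indices[b][:, 1]].reshape(n_idx, n_idx)
    for b in range(batch_size)
])
assert torch.equal(sub_S, expected)
```

Because `bidx` has size 1 in its last two dimensions, it is expanded to match `rows` and `cols`, so each output position picks up its own batch index.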

To eliminate the first for loop, include the batch size in the `size` argument of `torch.randint`:

``````
indices = torch.randint(high=size, size=(batch_size, n_idx**2, 2))
``````
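Putting both suggestions together might look like the sketch below. The helper name `indexing_batched` and its signature are hypothetical; it assumes `indices` is already a single tensor of shape `(batch_size, n_idx**2, 2)`, so the `torch.stack` call is no longer needed:

```python
import torch

batch_size, size, n_idx = 100, 24, 4

S = torch.arange(batch_size * size * size).reshape(batch_size, size, size)
# Batched mock indices, created without a Python loop
indices = torch.randint(high=size, size=(batch_size, n_idx**2, 2))

def indexing_batched(S, indices, n_idx):
    """Hypothetical helper: gather (n_idx, n_idx) blocks of entries without loops."""
    batch_size = S.shape[0]
    # (batch_size, n_idx**2, 2) -> (batch_size, n_idx, n_idx, 2)
    idx = indices.reshape(batch_size, n_idx, n_idx, 2)
    # (batch_size, 1, 1) batch index, broadcast against the row/col indices
    bidx = torch.arange(batch_size, device=S.device).view(-1, 1, 1)
    return S[bidx, idx[..., 0], idx[..., 1]]

sub_S = indexing_batched(S, indices, n_idx)
print(sub_S.shape)
```

Passing `n_idx` explicitly and reading the batch size from `S` avoids relying on module-level globals, which is a design choice of this sketch rather than something the code above requires.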

Hope this helps!

Could you explain a bit more the logic behind reshaping `indices` like `(-1, n_idx, n_idx, 2)`? Why does that in combination with `[...,0]` and `[...,1]` work? Thanks!