Extract batch of square submatrices from batch of square matrices

The following code extracts a batch of square submatrices sub_S from a batch of square matrices S. It is currently a bottleneck in a larger system. Does PyTorch provide any functionality to get rid of the for loops? So far I haven't found anything that improves it.

Here is a minimal working example:

import torch

# Parameters
batch_size = 2
size = 24
n_idx = 6

# Mock data
S = torch.arange(end=batch_size*size*size).reshape(batch_size, size, size)

# Mock indices
indices = list()
for _ in range(batch_size):
    indices.append(torch.randint(high=size, size=(n_idx**2, 2)))

sub_S = list()
for b in range(batch_size):
    sub_S.append(S[b][indices[b][:, 0], indices[b][:, 1]].reshape(n_idx, n_idx))

sub_S = torch.stack(sub_S)

Hi Samue1,

To eliminate the second for loop in your code, you can index the batch dimension as well, using a batch index tensor that broadcasts against the row and column indices. An example is shown below:

import torch

batch_size = 100
size = 24
n_idx = 4

# Mock data
S = torch.arange(end=batch_size*size*size).reshape(batch_size, size, size)

# Mock indices: one (n_idx**2, 2) tensor of (row, col) pairs per batch element
indices = list()
for _ in range(batch_size):
    indices.append(torch.randint(high=size, size=(n_idx**2, 2)))

def forlooping(S, indices):
    # Baseline: gather each submatrix separately, then stack
    sub_S = list()
    for b in range(batch_size):
        sub_S.append(S[b][indices[b][:, 0], indices[b][:, 1]].reshape(n_idx, n_idx))
    sub_S = torch.stack(sub_S)
    return sub_S

def indexing(S, indices):
    # Shape (batch_size, n_idx, n_idx, 2); the last dimension holds (row, col)
    indices = torch.stack(indices, dim=0).view(-1, n_idx, n_idx, 2)
    # Shape (batch_size, 1, 1); broadcasts against the row and column indices
    bidx = torch.arange(batch_size).view(-1,1,1)
    sub_S = S[bidx, indices[...,0], indices[...,1]]
    return sub_S

The speedup is noticeable with large batch sizes:

In [2]: %timeit indexing(S, indices)
155 µs ± 1.98 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [3]: %timeit forlooping(S, indices)
1.9 ms ± 14.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

To eliminate the first for loop, simply include the batch size in the size parameter, as below. indices is then already a single tensor, so the torch.stack call in indexing is no longer needed.

indices = torch.randint(high=size, size=(batch_size, n_idx**2, 2))
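Putting both suggestions together, a minimal sketch might look like the following. It assumes indices is created as a single tensor of shape (batch_size, n_idx**2, 2); extract_vectorized and extract_loop are hypothetical helper names used only for this illustration, and the loop version is kept just to check that both approaches agree.

```python
import torch

batch_size, size, n_idx = 100, 24, 4

# Mock data, same layout as in the question
S = torch.arange(batch_size * size * size).reshape(batch_size, size, size)

# Mock indices created in one call: (batch_size, n_idx**2, 2)
indices = torch.randint(high=size, size=(batch_size, n_idx**2, 2))

def extract_vectorized(S, indices, n_idx):
    # Hypothetical helper: gather all submatrices with one indexing operation
    rows = indices[..., 0].reshape(-1, n_idx, n_idx)  # (batch, n_idx, n_idx)
    cols = indices[..., 1].reshape(-1, n_idx, n_idx)  # (batch, n_idx, n_idx)
    bidx = torch.arange(S.shape[0]).view(-1, 1, 1)    # (batch, 1, 1)
    return S[bidx, rows, cols]

def extract_loop(S, indices, n_idx):
    # Hypothetical helper: the original per-batch loop, used as a reference
    return torch.stack([
        S[b][indices[b][:, 0], indices[b][:, 1]].reshape(n_idx, n_idx)
        for b in range(S.shape[0])
    ])

sub_S = extract_vectorized(S, indices, n_idx)
reference = extract_loop(S, indices, n_idx)
```

Comparing sub_S and reference with torch.equal is a quick way to confirm that the vectorized version returns the same values as the loop.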

Hope this helps!


Could you explain the logic behind reshaping indices to (-1, n_idx, n_idx, 2) a bit more? Why does that work in combination with [...,0] and [...,1]? Thanks!
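One way to look at it, offered as a rough sketch rather than an authoritative answer: the view arranges each batch element's flat list of n_idx**2 (row, col) pairs on an n_idx x n_idx grid, [...,0] and [...,1] then pick out the row grid and the column grid, and the (batch_size, 1, 1) batch index broadcasts against both. The result has the broadcast shape of the three index tensors, with out[b, i, j] = S[b, rows[b, i, j], cols[b, i, j]]. The small sizes and the names grid, rows, cols, out and expected below are assumptions chosen for this illustration.

```python
import torch

batch_size, size, n_idx = 2, 5, 2

S = torch.arange(batch_size * size * size).reshape(batch_size, size, size)
indices = torch.randint(high=size, size=(batch_size, n_idx**2, 2))

# (batch, n_idx**2, 2) -> (batch, n_idx, n_idx, 2): the pairs are laid out
# on a grid, and the last dimension still holds one (row, col) pair
grid = indices.view(-1, n_idx, n_idx, 2)

rows = grid[..., 0]  # (batch, n_idx, n_idx): row index for each output cell
cols = grid[..., 1]  # (batch, n_idx, n_idx): column index for each output cell
bidx = torch.arange(batch_size).view(-1, 1, 1)  # (batch, 1, 1)

# The three index tensors broadcast to (batch, n_idx, n_idx)
out = S[bidx, rows, cols]

# The same result written out element by element
expected = torch.empty_like(out)
for b in range(batch_size):
    for i in range(n_idx):
        for j in range(n_idx):
            expected[b, i, j] = S[b, rows[b, i, j], cols[b, i, j]]
```

Here out and expected hold the same values, which shows that the broadcast indexing is just the triple loop expressed as a single operation.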