Keep off-diagonal elements only from square matrix

I have a square pairwise distance matrix, say n x n, and I would like to get rid of the diagonal and work with an n x (n - 1) matrix instead. What’s a fast way of doing that?

One way is a.masked_select(~torch.eye(n, dtype=bool)).view(n, n - 1) but I was curious if there’s a faster approach.
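
For reference, here is a minimal runnable sketch of the mask approach on a tiny example. The matrix `a` is just a stand-in built with `torch.arange`, not real distance data, and plain boolean indexing (`a[mask]`) is used as an equivalent of `masked_select`.

```python
import torch

n = 3
# Stand-in for a pairwise distance matrix (assumption: any square tensor works).
a = torch.arange(n * n, dtype=torch.float32).view(n, n)

# True everywhere except on the diagonal.
mask = ~torch.eye(n, dtype=torch.bool)

# Boolean indexing returns the selected elements in row-major order,
# so each group of n - 1 values is one row without its diagonal entry.
off_diag = a[mask].view(n, n - 1)

print(off_diag)
# tensor([[1., 2.],
#         [3., 5.],
#         [6., 7.]])
```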

I do not know if it is really faster, but just for the fun of it you could try as an alternative:

a.flatten()[1:].view(n-1, n+1)[:,:-1].reshape(n, n-1)
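
If you want to convince yourself that this gives the same result as the mask version, something like the following should do. The helper name `off_diagonal` is made up for this sketch and is not part of PyTorch.

```python
import torch

def off_diagonal(a: torch.Tensor) -> torch.Tensor:
    """Return the n x (n - 1) off-diagonal elements of a square matrix.

    Hypothetical helper wrapping the one-liner from this thread.
    """
    n, m = a.shape
    assert n == m, "expected a square matrix"
    # Drop a[0, 0], regroup into n - 1 rows of n + 1 values so that every row
    # ends on a diagonal element, cut that last column, then reshape.
    return a.flatten()[1:].view(n - 1, n + 1)[:, :-1].reshape(n, n - 1)

# Compare against the boolean-mask version on a random matrix.
n = 5
a = torch.rand(n, n)
expected = a.masked_select(~torch.eye(n, dtype=torch.bool)).view(n, n - 1)
assert torch.equal(off_diagonal(a), expected)
```

The final call has to be `reshape` rather than `view`, because slicing off the last column leaves a non-contiguous tensor.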

This is indeed much faster!
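
A rough way to time the two variants yourself, as a sketch only: it runs on CPU, uses the standard library `timeit`, and the size `n = 1000` and repeat count are arbitrary choices. The mask is built once outside the timed function, so this measures the selection itself. On a GPU you would also need `torch.cuda.synchronize()` around the timed calls, and the numbers will depend on hardware and PyTorch version.

```python
import timeit
import torch

n = 1000          # arbitrary size for illustration
repeats = 20      # arbitrary repeat count
a = torch.rand(n, n)
mask = ~torch.eye(n, dtype=torch.bool)

def with_mask():
    return a.masked_select(mask).view(n, n - 1)

def with_reshape():
    return a.flatten()[1:].view(n - 1, n + 1)[:, :-1].reshape(n, n - 1)

# Both variants must agree before the timings mean anything.
assert torch.equal(with_mask(), with_reshape())

t_mask = timeit.timeit(with_mask, number=repeats)
t_reshape = timeit.timeit(with_reshape, number=repeats)
print(f"mask:    {t_mask / repeats * 1e3:.3f} ms per call")
print(f"reshape: {t_reshape / repeats * 1e3:.3f} ms per call")
```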

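If zeroing the diagonal is enough for your use case, an even faster option is

a * (1 - torch.eye(n))

Note that this keeps the n x n shape and only sets the diagonal entries to zero, so it does not give the n x (n - 1) matrix asked for in the question.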
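
A quick check of what the multiplication actually returns, on a small made-up matrix. The `fill_diagonal_` variant is shown only as an alternative that avoids building the identity matrix; it is applied to a clone so that `a` itself is not modified.

```python
import torch

n = 3
a = torch.arange(1, n * n + 1, dtype=torch.float32).view(n, n)

# Multiply by a matrix that is 0 on the diagonal and 1 elsewhere.
zeroed = a * (1 - torch.eye(n))
print(zeroed.shape)  # torch.Size([3, 3])
print(zeroed)
# tensor([[0., 2., 3.],
#         [4., 0., 6.],
#         [7., 8., 0.]])

# Alternative: zero the diagonal of a copy in place.
zeroed_alt = a.clone().fill_diagonal_(0)
assert torch.equal(zeroed, zeroed_alt)
```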

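For those who want a visualization of how the flatten/view approach works (here with n = 3):

import torch

n = 3
a = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
])

# flatten
a.flatten()      # [1,2,3,4,5,6,7,8,9]
a.flatten()[1:]  # [2,3,4,5,6,7,8,9]  drops the first diagonal element

# view
a.flatten()[1:].view(n-1, n+1)  # each row holds the n off-diagonal elements after one diagonal element, then ends on the next diagonal element
# [
#    [2,3,4,5]
#    [6,7,8,9]
# ]

# cut
a.flatten()[1:].view(n-1, n+1)[:, :-1]  # removes the last element of every row, which is a diagonal element
# [
#    [2,3,4]
#    [6,7,8]
# ]

# reshape
a.flatten()[1:].view(n-1, n+1)[:, :-1].reshape(n, n-1)  # regroups the remaining elements into the original rows
# [
#    [2,3]
#    [4,6]
#    [7,8]
# ]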
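
The reason this works is that the diagonal elements sit at flat indices 0, n + 1, 2(n + 1), and so on, which are exactly the positions the slice and the cut throw away. As a sanity check on a slightly larger example, here is a small sketch that verifies the index mapping: row i of the result is row i of `a` with column i skipped. The variable names are only for illustration.

```python
import torch

n = 4
a = torch.arange(n * n).view(n, n)
out = a.flatten()[1:].view(n - 1, n + 1)[:, :-1].reshape(n, n - 1)

for i in range(n):
    for j in range(n - 1):
        # Columns before the diagonal keep their index, later ones shift left by one.
        src_col = j if j < i else j + 1
        assert out[i, j] == a[i, src_col]

print(out)
# tensor([[ 1,  2,  3],
#         [ 4,  6,  7],
#         [ 8,  9, 11],
#         [12, 13, 14]])
```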