I have a (square) pairwise distance matrix, say n x n, and I would like to get rid of the diagonal and work with an n x (n - 1) matrix instead. What's a fast way of doing that?
One way is a.masked_select(~torch.eye(n, dtype=bool)).view(n, n - 1), but I was curious whether there's a faster approach.
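For reference, here is the mask-based one-liner from the question as a small self-contained sketch, applied to a distance matrix built with torch.cdist. The function name remove_diagonal_masked is hypothetical; its body is just the expression above, with the mask placed on the same device as the input.

```python
import torch

def remove_diagonal_masked(a: torch.Tensor) -> torch.Tensor:
    """Drop the diagonal of a square (n, n) tensor, returning shape (n, n - 1)."""
    n = a.shape[0]
    # True everywhere except the diagonal, on the same device as `a`
    mask = ~torch.eye(n, dtype=torch.bool, device=a.device)
    # masked_select returns the kept elements as a new 1-D tensor in row-major order;
    # every row keeps exactly n - 1 elements, so viewing it as (n, n - 1) is valid
    return a.masked_select(mask).view(n, n - 1)

x = torch.randn(5, 3)         # 5 points in 3-D
dist = torch.cdist(x, x)      # (5, 5) pairwise Euclidean distances
off_diag = remove_diagonal_masked(dist)
print(off_diag.shape)         # torch.Size([5, 4])
```

Note that this allocates and scans a full n x n boolean mask on every call, which is the overhead the question is hoping to avoid.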
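An alternative avoids building a mask altogether: flatten the matrix, drop its first element, view the remainder as an (n - 1) x (n + 1) matrix so that every remaining diagonal element lands in the last column, slice that column off, and reshape to n x (n - 1). For those who want a visualization of how this works: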
import torch

n = 3
a = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
])
# flatten
a.flatten() # [1,2,3,4,5,6,7,8,9]
a.flatten()[1:] # [2,3,4,5,6,7,8,9] # drops a[0, 0], the first diagonal element
# view
a.flatten()[1:].view(n - 1, n + 1) # each row now starts just after one diagonal element and ends on the next one
# [
# [2,3,4,5]
# [6,7,8,9]
# ]
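# cut and reshape
a.flatten()[1:].view(n - 1, n + 1)[:, :-1] # removes the last element from every row, which is always a diagonal element
# [
# [2,3,4]
# [6,7,8]
# ]
a.flatten()[1:].view(n - 1, n + 1)[:, :-1].reshape(n, n - 1) # reshape, not view: the slice is non-contiguous
# [
# [2,3]
# [4,6]
# [7,8]
# ]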
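The steps above can be wrapped into a reusable function and checked against the mask-based version. This is a sketch; both function names are hypothetical. The trick works because a[i, i] sits at flat index i * (n + 1); once index 0 is dropped, the diagonal element of row i (for i >= 1) sits at index i * (n + 1) - 1, which is exactly the last column of row i - 1 in the (n - 1, n + 1) view. Only the final reshape copies data, whereas the mask version also builds and scans an n x n boolean mask. Whether that is measurably faster depends on n and the device, so it is worth timing on your own sizes (for example with torch.utils.benchmark.Timer).

```python
import torch

def remove_diagonal_flatten(a: torch.Tensor) -> torch.Tensor:
    """Drop the diagonal of a square (n, n) tensor, returning shape (n, n - 1)."""
    n = a.shape[0]
    # Skip a[0, 0]; each row of the (n - 1, n + 1) view then ends on a diagonal element.
    # Splitting a single 1-D dimension is always a valid view, whatever its stride.
    rows = a.flatten()[1:].view(n - 1, n + 1)
    # Drop the last column (all diagonal elements); the slice is non-contiguous,
    # so use reshape (which copies when a view is impossible) rather than view
    return rows[:, :-1].reshape(n, n - 1)

def remove_diagonal_masked(a: torch.Tensor) -> torch.Tensor:
    # Mask-based version from the question, kept here for comparison
    n = a.shape[0]
    mask = ~torch.eye(n, dtype=torch.bool, device=a.device)
    return a.masked_select(mask).view(n, n - 1)

a = torch.arange(1, 10).view(3, 3)
print(remove_diagonal_flatten(a))
# tensor([[2, 3],
#         [4, 6],
#         [7, 8]])

d = torch.rand(100, 100)
print(torch.equal(remove_diagonal_flatten(d), remove_diagonal_masked(d)))  # True
```

Because flatten copies a non-contiguous input (e.g. a transposed matrix) when it cannot return a view, the function also works on strided inputs, just with an extra copy.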