I was wondering what the most efficient way is to get the off-diagonal elements of a batch of matrices.
Let’s assume I have a Tensor of shape [B, N, N] and want to keep all off-diagonal elements (i.e. set every diagonal element to 0). I’m currently using this:
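    import torch

    def get_off_diagonal_elements(M):
        # M has shape [..., N, N]; the last two dims must be square
        if M.shape[-2] != M.shape[-1]:
            raise ValueError("last two dimensions must be equal (square matrices)")
        # [N, N] mask: 0 on the diagonal, 1 elsewhere; broadcasts over the batch dim
        mask = 1 - torch.eye(M.shape[-1], dtype=M.dtype, device=M.device)
        return M * mask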
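A minimal sketch of a variant of the mask approach, assuming it is fine to return a new tensor: it builds a boolean [N, N] identity and uses masked_fill, which broadcasts the mask over the batch dimension, keeps M's dtype, and writes an exact 0 on the diagonal. The function name off_diagonal_masked is hypothetical.

```python
import torch


def off_diagonal_masked(M: torch.Tensor) -> torch.Tensor:
    """Return a copy of M (shape [..., N, N]) with every diagonal entry set to 0."""
    if M.dim() < 2 or M.shape[-2] != M.shape[-1]:
        raise ValueError(f"expected shape [..., N, N], got {tuple(M.shape)}")
    n = M.shape[-1]
    # Boolean [N, N] identity: True marks the diagonal positions.
    diag = torch.eye(n, dtype=torch.bool, device=M.device)
    # masked_fill broadcasts the [N, N] mask over the leading batch dims
    # and returns a new tensor with the same dtype as M (M is not modified).
    return M.masked_fill(diag, 0)


M = torch.arange(2 * 3 * 3).reshape(2, 3, 3)  # integer batch, B=2, N=3
out = off_diagonal_masked(M)
print(out)
```

One reason to prefer masked_fill over multiplying by a 0/1 mask: if a diagonal entry is inf or NaN, multiplying it by 0 gives NaN, whereas masked_fill simply overwrites it with 0.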
While typing out this question, I had the idea of just cloning the Tensor and filling its diagonal with 0s. However, the version below fails with an error:
    def get_off_diagonal_elements_fast(M):
        return M.clone().fill_diagonal_(0)
The error is:

    RuntimeError: all dimensions of input must be of equal length

Also, do I need to call .clone() when returning the Tensor?
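As far as I can tell from the PyTorch docs, fill_diagonal_ is not batched: for inputs with more than two dimensions it requires every dimension to have the same length and fills the single main diagonal of the whole tensor (as I understand it, the entries [i, i, i]), not the diagonal of each matrix. That would explain why a [B, N, N] tensor with B != N raises. A sketch of a per-matrix alternative, assuming a [..., N, N] input, takes a view of every matrix's diagonal with Tensor.diagonal(dim1=-2, dim2=-1) and zeroes it in place; the function name off_diagonal_inplace_view is hypothetical. On the .clone() question: zero_() (like fill_diagonal_) writes in place, so without the clone the caller's tensor is modified too, as the last few lines show.

```python
import torch


def off_diagonal_inplace_view(M: torch.Tensor) -> torch.Tensor:
    """Return a copy of M (shape [..., N, N]) with each matrix's diagonal zeroed."""
    if M.dim() < 2 or M.shape[-2] != M.shape[-1]:
        raise ValueError(f"expected shape [..., N, N], got {tuple(M.shape)}")
    out = M.clone()  # copy first so the caller's tensor is left untouched
    # diagonal(dim1=-2, dim2=-1) is a view of shape [..., N] sharing storage
    # with `out`, so zeroing it in place zeroes the diagonal of every matrix.
    out.diagonal(dim1=-2, dim2=-1).zero_()
    return out


torch.manual_seed(0)
M = torch.randn(4, 3, 3)  # B=4, N=3: fill_diagonal_ would raise on this shape
out = off_diagonal_inplace_view(M)

# Without clone(), the same in-place write mutates the input itself:
x = torch.ones(2, 3, 3)
x.diagonal(dim1=-2, dim2=-1).zero_()
print(x.sum())  # 6 ones remain per matrix, 12 in total
```

Since the diagonal is written through a view, this avoids building an [N, N] mask at all; the only extra allocation is the clone itself.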
Any help would be appreciated!