Hi All,
I was wondering what’s the most efficient way to grab the off-diagonal elements of a batch of matrices?
Let’s assume I have some Tensor of shape [B,N,N] and wish to grab all off-diagonal elements (i.e. set all diagonal elements to 0). I’m currently using this:
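```python
import torch

def get_off_diagonal_elements(M):
    # Expect a tensor of shape [..., N, N]
    if M.shape[-2] != M.shape[-1]:
        raise ValueError(f"Expected square matrices, got shape {tuple(M.shape)}")
    # 1 off the diagonal, 0 on it; broadcasts over the batch dimension
    mask = 1 - torch.eye(M.shape[-1], device=M.device)
    return M * mask
```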
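A minimal sketch of an out-of-place alternative, assuming the goal is only to zero the diagonal of the last two dimensions: build a boolean identity mask once and pass it to `masked_fill`, which broadcasts it over the batch dimension. One behavioral difference from multiplying by `1 - eye`: an `inf` or `nan` on the diagonal stays `nan` after multiplication (since `inf * 0` is `nan`), whereas `masked_fill` simply overwrites those entries with 0. The name `zero_diagonal_masked` is hypothetical.

```python
import torch

def zero_diagonal_masked(M: torch.Tensor) -> torch.Tensor:
    """Return a copy of M with the diagonal of each trailing N x N matrix set to 0."""
    if M.dim() < 2 or M.shape[-2] != M.shape[-1]:
        raise ValueError(f"Expected shape [..., N, N], got {tuple(M.shape)}")
    n = M.shape[-1]
    # Boolean [N, N] identity; broadcasts against [..., N, N]
    diag_mask = torch.eye(n, dtype=torch.bool, device=M.device)
    # masked_fill is out of place, so M itself is left untouched
    return M.masked_fill(diag_mask, 0)

# Usage: a batch of 2 matrices of size 3 x 3
M = torch.arange(18, dtype=torch.float32).reshape(2, 3, 3)
out = zero_diagonal_masked(M)
```

Because the mask is boolean, no arithmetic is done on the input values, so the dtype of `M` is preserved as well.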
While typing out this question, I had the idea of just copying the Tensor and filling the diagonal with 0s. However, when I tried the following, it failed with an error:
```python
import torch

def get_off_diagonal_elements_fast(M):
    return M.clone().fill_diagonal_(0)
```
The error is:

```
RuntimeError: all dimensions of input must be of equal length
```

Also, do I need to call `.clone()` when returning the Tensor, given that `fill_diagonal_` works in place?
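On the error: per the PyTorch docs, `fill_diagonal_` fills the main diagonal of the tensor as a whole, and for more than 2 dimensions it requires all dimensions to have equal length. For a 3-D tensor that diagonal is the entries `M[i, i, i]`, not the diagonal of each matrix in the batch, so even when B == N it would zero the wrong entries. A sketch that works per matrix, assuming the view semantics of `Tensor.diagonal(dim1=-2, dim2=-1)`: take the batched diagonal view and zero it in place. This also bears on the `.clone()` question: the write goes through a view into the original storage, so cloning first is what keeps the caller's `M` unchanged. The name `zero_diagonal_view` is hypothetical.

```python
import torch

def zero_diagonal_view(M: torch.Tensor) -> torch.Tensor:
    """Return a copy of M with the diagonal of each trailing N x N matrix set to 0."""
    out = M.clone()  # without this, the in-place write below would modify M itself
    # diagonal(dim1=-2, dim2=-1) is a view of shape [..., N]; zero_() writes through to out
    out.diagonal(dim1=-2, dim2=-1).zero_()
    return out

# Usage: B = N = 3, so fill_diagonal_ does not raise, but it only zeros M[i, i, i]
M = torch.ones(3, 3, 3)
per_matrix = zero_diagonal_view(M)          # zeros 3 entries in each of the 3 matrices
whole_tensor = M.clone().fill_diagonal_(0)  # zeros only the 3 entries M[i, i, i]
```

If you do not need the input preserved, dropping the `.clone()` and calling `M.diagonal(dim1=-2, dim2=-1).zero_()` directly saves a copy, at the cost of mutating `M`.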
Any help would be appreciated!