Most efficient way to get just the off-diagonal elements of a Tensor?

Hi All,

I was wondering what’s the most efficient way to grab the off-diagonal elements of a batch of matrices?

Let’s assume I have a Tensor of shape [B, N, N] and want all off-diagonal elements, i.e. every diagonal element set to 0. I’m currently using this:

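import torch

def get_off_diagonal_elements(M):
    # M has shape [..., N, N]; the [N, N] mask broadcasts over the leading batch dims
    if M.shape[-2] != M.shape[-1]:
        raise ValueError("expected square matrices in the last two dimensions")
    mask = 1 - torch.eye(M.shape[-1], device=M.device, dtype=M.dtype)  # 0 on the diagonal, 1 elsewhere
    return M * mask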
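Because the [N, N] mask broadcasts over the batch dimension, one mask serves all B matrices. Below is a minimal sketch of two variants of the same idea; the helper names are hypothetical. The first uses a boolean identity mask with masked_fill instead of a multiply (out of place, and it keeps M's dtype). The second uses boolean indexing to pull out only the N*(N-1) off-diagonal values per matrix rather than zeroing the diagonal, and it assumes M is exactly 3-D ([B, N, N]).

```python
import torch

def zero_diagonal_masked(M: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: boolean identity mask of shape [N, N], broadcast over the batch dim
    n = M.shape[-1]
    eye = torch.eye(n, dtype=torch.bool, device=M.device)
    # masked_fill is out of place, so M itself is left unchanged
    return M.masked_fill(eye, 0)

def off_diagonal_values(M: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (assumes M is [B, N, N]): returns only the off-diagonal
    # entries as shape [B, N*(N-1)], in row-major order with the diagonal skipped
    n = M.shape[-1]
    off = ~torch.eye(n, dtype=torch.bool, device=M.device)
    return M[:, off]

M = torch.arange(2 * 3 * 3, dtype=torch.float32).reshape(2, 3, 3)
print(zero_diagonal_masked(M)[0])
print(off_diagonal_values(M).shape)  # torch.Size([2, 6])
```

Whether either is faster than the multiply depends on sizes and device, so it is worth timing on real data.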

While typing out this question, I had the idea of just copying the Tensor and filling its diagonal with 0s. However, the following attempt fails:

def get_off_diagonal_elements_fast(M):
  return M.clone().fill_diagonal_(0)

The error is:

RuntimeError: all dimensions of input must be of equal length

Also, do I need to call .clone() when returning the Tensor?

Any help would be appreciated!

You need the clone if you want to keep M unchanged.

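import torch

def get_off_diagonal_elements(M):
    res = M.clone()  # copy so the caller's M is not modified
    res.diagonal(dim1=-1, dim2=-2).zero_()  # zero the diagonal of every matrix in the batch, in place
    return res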
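A quick check of the function above, assuming a random batch of shape [4, 5, 5]. It also illustrates why fill_diagonal_ raised the error in the question: for tensors with more than two dimensions it requires every dimension to have the same length and fills the entries at [i, i, ..., i], so it is not a per-matrix operation on a [B, N, N] batch. diagonal(dim1=..., dim2=...) instead takes the diagonal over just the two named dims, for every batch element.

```python
import torch

def get_off_diagonal_elements(M):
    res = M.clone()
    res.diagonal(dim1=-1, dim2=-2).zero_()
    return res

M = torch.randn(4, 5, 5)
M_before = M.clone()
out = get_off_diagonal_elements(M)

# Same result as the mask approach, and M itself is untouched
mask = 1 - torch.eye(5)
print(torch.equal(out, M * mask))  # True
print(torch.equal(M, M_before))  # True

# fill_diagonal_ works on a single 2-D matrix...
A = torch.ones(3, 3)
A.fill_diagonal_(0)

# ...but on a [B, N, N] batch with B != N it raises a RuntimeError
try:
    torch.ones(4, 5, 5).fill_diagonal_(0)
except RuntimeError as e:
    print("RuntimeError:", e)
```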

Perfect, thank you! Could I ask a quick follow-up? I assume the .clone() call is due to shallow copying issues?

Yes, but the terminology here is views. diagonal() returns the diagonal as a view onto the same underlying storage as the Tensor it was called on, and methods ending with _ modify their Tensor in place. So without the clone, zero_() would write into M itself.
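A small sketch of that view behaviour, assuming a 2 x 3 x 3 batch of ones: the view returned by diagonal() shares memory with M, so an in-place op on it changes M, while a clone has its own memory.

```python
import torch

M = torch.ones(2, 3, 3)

d = M.diagonal(dim1=-2, dim2=-1)  # view of shape [2, 3], no data copied
print(d.shape)  # torch.Size([2, 3])

d.zero_()  # in-place op on the view writes through to M's storage
print(M[0])  # the diagonal of each matrix in M is now 0

# A clone owns separate memory, so in-place changes to it leave M alone
c = M.clone()
c.diagonal(dim1=-2, dim2=-1).fill_(5)
print(M.diagonal(dim1=-2, dim2=-1))  # still all zeros
```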

Best regards

