Simple masking for diagonals

Hello everyone,

I have a simple issue, but I am not sure how to deal with it.

Let’s say I have a tensor of shape (B, N, N, F), where B is the batch size, N is the size of each square matrix, and F is the dimension of my feature space. What is the most straightforward way to mask the diagonals, i.e. set them to 0?


You could try the following:

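import numpy as np
import torch

a = torch.randn(10, 10)
# For a square 2D tensor this gives step = 1 + 10 = 11, the flat distance between diagonal entries
step = int(1 + (np.cumprod(a.shape[:-1])).sum())
a.view(-1)[::step] = 0  # zero every step-th element of the flattened view, i.e. the diagonal
a.view(10, 10)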

EDIT: Sorry, I just realized that this formula might give wrong results when the dimension sizes differ.
What is your definition of diagonal? The diagonal of each square [N, N] matrix?
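A possible general fix, sketched under the assumption that the tensor is 2D and contiguous: in a row-major [M, K] tensor the element a[i, i] sits at flat offset i * (K + 1), which equals i * sum(a.stride()). The formula above uses the leading dimension (1 + M) instead, which only coincides with K + 1 when the matrix is square. For non-square inputs the slice also has to stop after min(M, K) entries, otherwise it wraps into later rows. The helper name `zero_diagonal_strided` is hypothetical; `torch.diagonal` is used only as a cross-check.

```python
import torch

def zero_diagonal_strided(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: zero a[i, i] in place for a contiguous 2D tensor."""
    assert a.dim() == 2 and a.is_contiguous()
    step = sum(a.stride())   # K + 1 for a row-major [M, K] tensor
    n_diag = min(a.shape)    # the main diagonal has min(M, K) entries
    # Stop after n_diag entries so the slice never wraps into the next rows
    a.view(-1)[: n_diag * step : step] = 0
    return a

for shape in [(4, 4), (3, 5), (5, 3)]:
    a = torch.randn(*shape)
    expected = a.clone()
    expected.diagonal().zero_()  # reference result, written through a diagonal view
    zero_diagonal_strided(a)
    print(shape, torch.equal(a, expected))
```

For the (5, 3) case, slicing with step 4 and no stop would also hit offset 12, which is a[4, 0] rather than a diagonal entry.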


Hello Kim,

how about the following?

import torch

x = torch.randn(2, 5, 5, 1)
idx = torch.arange(5)  # int64 indices 0..4
x[:, idx, idx] = 0     # zeroes x[b, i, i, :] for every batch b and every feature

Best regards


P.S.: I’m not sure about the performance characteristics of this vs. Peter’s solution.
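For comparison, here is a small sketch (not from the thread) of two other ways to zero the diagonal of every [N, N] slice of a (B, N, N, F) tensor: an in-place write through a `torch.diagonal` view, and an out-of-place `masked_fill` with a broadcast `torch.eye` mask. It only checks that the three approaches agree and makes no claim about which is fastest; the function names are hypothetical.

```python
import torch

B, N, F = 2, 5, 3

def zero_diag_index(x):
    # Advanced indexing as in the answer above: x[b, i, i, :] = 0
    idx = torch.arange(x.size(1))
    x = x.clone()
    x[:, idx, idx] = 0
    return x

def zero_diag_view(x):
    # diagonal(dim1=1, dim2=2) is a view of shape (B, F, N); zero_() writes through it
    x = x.clone()
    x.diagonal(dim1=1, dim2=2).zero_()
    return x

def zero_diag_mask(x):
    # Out-of-place: an (N, N) boolean identity reshaped to (1, N, N, 1) broadcasts over B and F
    eye = torch.eye(x.size(1), dtype=torch.bool, device=x.device)
    return x.masked_fill(eye[None, :, :, None], 0)

x = torch.randn(B, N, N, F)
a, b, c = zero_diag_index(x), zero_diag_view(x), zero_diag_mask(x)
print(torch.equal(a, b), torch.equal(a, c))
```

The first two variants modify their input in place (the `clone()` calls are only there so the results can be compared), while the mask variant allocates a new tensor. To compare speed on your own sizes and device, `torch.utils.benchmark.Timer` can be wrapped around each function.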


Hi Patrick,
and thanks for your answer!!

Yes, I meant the diagonal of each of my F [N, N] matrices, for every batch element.

But let’s look at the question another way. Let’s say I have a matrix M of shape [N, N] with zeros in some positions, which I want to use as a mask for all the F matrices. The goal is to set to zero, in every matrix, the elements at the same indices where M has zeros. Is there an efficient way to do that in PyTorch? :slight_smile:

I hope I explained it well; otherwise, I will try to rephrase the question.

As far as I understand your problem, you have two tensors. Let’s name them m = [N, N] and x = [batch_size, N, N, F].
Somewhere in m there are zeros, and you would like to get their indices.
Using these indices, you would like to set the corresponding entries of x to zero for each batch element and each F.
Is this the right understanding?

If so, you could try the following:

import torch

batch_size = 5
F = 7
N = 5

# Create m matrix and mask some values
m = torch.randn(N, N)
m[0, 4] = 0
m[2, 1] = 0
m[4, 3] = 0

# Get zero indices
mask_idx = (m == 0).nonzero()

# Create x and mask same indices
x = torch.randn(batch_size, N, N, F)
x[:, mask_idx[:, 0], mask_idx[:, 1], :] = 0

# Print some results
print(x[0, :, :, 0])
print(x[1, :, :, 0])
print(x[2, :, :, 3])
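If the indices themselves are not needed, a possible alternative (a sketch, not the method from the answer) is to keep `m == 0` as a boolean mask and let broadcasting apply it to every batch element and feature channel. Here `masked_fill` returns a new tensor; `masked_fill_` would modify `x` in place instead. The snippet rebuilds the same `m` and `x` as above and checks that both versions agree.

```python
import torch

batch_size, N, F = 5, 5, 7

m = torch.randn(N, N)
m[0, 4] = 0
m[2, 1] = 0
m[4, 3] = 0

x = torch.randn(batch_size, N, N, F)

# Index-based version from the answer above (on a copy)
y_idx = x.clone()
mask_idx = (m == 0).nonzero()
y_idx[:, mask_idx[:, 0], mask_idx[:, 1], :] = 0

# Boolean-mask version: (N, N) -> (1, N, N, 1) broadcasts over batch and F
y_mask = x.masked_fill((m == 0)[None, :, :, None], 0)

print(torch.equal(y_idx, y_mask))
```

Multiplying by `(m != 0)[None, :, :, None]` would give the same result for finite values, but `masked_fill` also overwrites `inf` or `NaN` entries, which a multiplication by 0 would turn into `NaN`.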

Super, that’s precisely what I wanted to do!
And it is also much less tricky than how I implemented it before…
Thanks a lot!!! :slight_smile: :slight_smile: