I have a simple issue, but I am not sure how to deal with it.

Let’s say I have a tensor of shape (B, N, N, F), where B is the batch size, N is the size of a square matrix, and F is the dimension of my feature space. What is the most straightforward way to mask the diagonals with 0?

EDIT: Sorry, I just realized that this formula might lead to wrong results when the dimension sizes differ.
What is your definition of diagonal? Each diagonal of the square matrix [N, N]?

Yes, I meant the diagonal of each of my F [N, N] matrices, for every element of the batch.

But let’s look at the question another way. Let’s say I have a matrix M = [N, N] with zeros somewhere in it, and I want to use it as a mask for all F matrices. The goal is to set to zero, in every one of those matrices, the elements at the same indices where M has zeros. Is there an efficient way to do that in PyTorch?

I hope I explained it well; otherwise I will try to phrase the question better.

As far as I understand your problem, you have two tensors. Let’s name them m = [N, N] and x = [batch_size, N, N, F].
Somewhere in m are zeros, and you would like to get their indices.
Using these indices, you would like to set the corresponding entries of x to zero for every batch element and every feature.
Is this the right understanding?

If so, you could try the following:

import torch

batch_size = 5
F = 7
N = 5
# Create m matrix and mask some values
m = torch.randn(N, N)
m[0, 4] = 0
m[2, 1] = 0
m[4, 3] = 0
print(m)
# Get zero indices
mask_idx = (m == 0).nonzero()
# Create x and mask same indices
x = torch.randn(batch_size, N, N, F)
x[:, mask_idx[:, 0], mask_idx[:, 1], :] = 0
# Print some results
print(x[0, :, :, 0])
print(x[1, :, :, 0])
print(x[2, :, :, 3])
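
An alternative that avoids collecting indices is to broadcast a boolean mask over the batch and feature dimensions. The following is only a minimal sketch under a few assumptions: the helper names `apply_mask` and `zero_diagonal` are made up for illustration, `x` has shape (B, N, N, F), and `m` has shape (N, N). The diagonal case from the original question is handled by building `m` from `torch.eye`.

```python
import torch


def apply_mask(x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
    """Return a copy of x with x[b, i, j, f] set to 0 wherever m[i, j] == 0.

    Hypothetical helper: assumes x is (B, N, N, F) and m is (N, N).
    """
    # Boolean mask of positions to zero out, reshaped to (1, N, N, 1)
    # so that it broadcasts over the batch and feature dimensions.
    drop = (m == 0).view(1, m.size(0), m.size(1), 1)
    # masked_fill is out-of-place; x itself is left unchanged.
    return x.masked_fill(drop, 0.0)


def zero_diagonal(x: torch.Tensor) -> torch.Tensor:
    """Zero the diagonal of every [N, N] matrix in x (hypothetical helper)."""
    n = x.size(1)
    # 1 everywhere except on the diagonal, where it is 0.
    m = 1 - torch.eye(n, device=x.device)
    return apply_mask(x, m)


x = torch.randn(5, 5, 5, 7)
m = torch.randn(5, 5)
m[0, 4] = 0
m[2, 1] = 0
m[4, 3] = 0

print(apply_mask(x, m)[0, :, :, 0])   # zeros at (0, 4), (2, 1), (4, 3)
print(zero_diagonal(x)[0, :, :, 0])   # zeros on the diagonal
```

I used `masked_fill` here rather than multiplying by a 0/1 mask because multiplication would leave any NaN or inf entries of `x` as NaN instead of setting them to zero. If you want to modify `x` in place, `x.masked_fill_(drop, 0.0)` should work the same way.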