Set diagonal of each matrix in a batch to 0

I have a tensor of shape N × M × M, where N is the batch size. I want to set the diagonal of each M × M matrix to a specific number (e.g. 0). I tried the fill_diagonal_ method, but it doesn’t work if N != M.

I can loop through the batch dimension, but I would prefer to avoid a loop.

Help would be appreciated.

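import torch

N, M, value = 5, 3, 3.1415926
x = torch.rand(N, M, M)
# Boolean identity mask, repeated once for every matrix in the batch
mask = torch.eye(M).repeat(N, 1, 1).bool()
x[mask] = value  # writes only the diagonal entries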
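
The mask probably does not need to be repeated N times. As a minimal sketch, assuming masked_fill_ broadcasts a single (M, M) boolean mask over the batch dimension (which is how I understand its documented behaviour), the same result can be had without the N copies. The `original` clone is only there so the result can be checked.

```python
import torch

N, M, value = 5, 3, 3.1415926
x = torch.rand(N, M, M)
original = x.clone()  # kept only to compare against afterwards

# One (M, M) boolean identity mask; it is broadcast across the batch dimension.
mask = torch.eye(M, dtype=torch.bool)

# In-place fill of every position where the mask is True.
x.masked_fill_(mask, value)
```

Afterwards, x[:, mask] should hold only `value`, and x[:, ~mask] should match original[:, ~mask].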

Thank you for the help.

A simpler way using less memory:

import torch

N, M, value = 5, 3, 0
x = torch.rand(N, M, M)
i = torch.arange(M)
# Using the same index for rows and columns selects each matrix's diagonal
x[:, i, i] = value
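
Another option that may be worth trying is torch.diagonal, which returns a view of the diagonals rather than a copy, so an in-place fill on that view writes through to the original tensor. A minimal sketch, assuming the matrices live in the last two dimensions (the `original` clone is only for checking):

```python
import torch

N, M, value = 5, 3, 0.0
x = torch.rand(N, M, M)
original = x.clone()  # kept only to compare against afterwards

# View of shape (N, M): one row of diagonal entries per matrix in the batch.
diag = torch.diagonal(x, dim1=-2, dim2=-1)

# Filling the view in place modifies x itself; no index or mask tensor is built.
diag.fill_(value)
```

This should leave the off-diagonal entries untouched and also works for any number of leading batch dimensions, since dim1 and dim2 are counted from the end.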