Efficient way to zero out certain diagonals

I have a huge matrix A, and I want to zero out diagonals 2, -2, and 4. What is an efficient way to do this in PyTorch?

Hi Adonai!

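You can use torch.diagonal() (or the Tensor.diagonal() method) to get a writable
view of each desired diagonal of your matrix, and then call zero_() on that view
to zero it out in place. Because the view shares storage with A, only the
elements on that diagonal are written, and no copy of the matrix is made: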

>>> import torch
>>> torch.__version__
'1.7.1'
>>> A = torch.arange (36).view ((6, 6))
>>> A
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]])
>>> A.diagonal (2).zero_()
tensor([0, 0, 0, 0])
>>> A.diagonal (-2).zero_()
tensor([0, 0, 0, 0])
>>> A.diagonal (4).zero_()
tensor([0, 0])
>>> A
tensor([[ 0,  1,  0,  3,  0,  5],
        [ 6,  7,  8,  0, 10,  0],
        [ 0, 13, 14, 15,  0, 17],
        [18,  0, 20, 21, 22,  0],
        [24, 25,  0, 27, 28, 29],
        [30, 31, 32,  0, 34, 35]])
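If you have several offsets to clear, the same calls can be wrapped in a loop. The following is a minimal sketch, not part of the original answer: `zero_diagonals_` is a hypothetical helper name (the trailing underscore follows PyTorch's convention for in-place operations), and passing dim1=-2, dim2=-1 to diagonal() is an assumption that lets the same code also handle a batch of matrices stacked along leading dimensions.

```python
import torch

def zero_diagonals_(A: torch.Tensor, offsets) -> torch.Tensor:
    """Zero the given diagonals of A in place and return A.

    Hypothetical helper. Diagonals are taken over the last two
    dimensions, so A may be a single matrix or a batch of matrices.
    """
    for k in offsets:
        # diagonal() returns a view that shares storage with A, so zero_()
        # writes directly into A and touches only that diagonal's elements.
        A.diagonal(k, dim1=-2, dim2=-1).zero_()
    return A

A = torch.arange(36).view(6, 6)
zero_diagonals_(A, [2, -2, 4])
print(A)
```

Since this modifies A in place, it will raise an error if A is a leaf tensor with requires_grad=True; in that case, apply it to a clone of A (or build an out-of-place version) instead.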

Best.

K. Frank
