Hi everyone,

I am wondering how to replace the blocks along the diagonal of a tensor with other values in a batched manner. For example, take a tensor of shape [B, 3, 6], where B is the batch size (B = 1 here for illustration):

[[[ 0,  1,  2,  3,  4,  5],
  [ 6,  7,  8,  9, 10, 11],
  [12, 13, 14, 15, 16, 17]]]
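To make the block structure explicit, here is a small sketch that rebuilds the example tensor and splits its last dimension into (block column, position inside the block). It assumes the blocks are always 1 x k (k = 2 here) and that the last dimension equals R * k, where R is the number of rows; the names `R`, `k` and `blocks` are just illustrative.

```python
import torch

B, R, k = 1, 3, 2  # batch size, number of rows, block width (assumed)

# The example tensor from above, shape [B, 3, 6]
x = torch.arange(B * R * R * k).view(B, R, R * k)

# Reshape to [B, R, R, k]: blocks[b, i, j] is the 1 x k block in row i
# that covers columns j*k .. j*k + k - 1. This is a view, not a copy.
blocks = x.view(B, R, R, k)

print(blocks[0, 1, 1])  # the diagonal block in row 1
```

In this view, the blocks to be replaced are exactly those with i == j, i.e. `[0, 1]`, `[8, 9]` and `[16, 17]`.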

Now I would like to replace each 1 × 2 block along the diagonal with new values, say [1, 1], which transforms the tensor above into:

[[[ **1, 1**, 2, 3, 4, 5],
  [ 6, 7, **1, 1**, 10, 11],
  [12, 13, 14, 15, **1, 1**]]]
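One loop-free way to express this transformation is plain advanced indexing: build a row index of shape [R, 1] and a column index of shape [R, k], and let them broadcast so that `out[:, rows, cols]` selects every diagonal block of every batch element at once. This is a sketch assuming a last dimension of exactly R * k; `replace_diag_blocks` is a hypothetical helper name, not a PyTorch API.

```python
import torch


def replace_diag_blocks(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return a copy of x (shape [B, R, R*k]) whose
    1 x k diagonal blocks are replaced by `values`.

    `values` may have shape [k] (same block everywhere) or any shape that
    broadcasts to [B, R, k] (e.g. a different block per row and batch).
    """
    B, R, C = x.shape
    k = C // R
    assert R * k == C, "assumes the last dimension is R * k"
    rows = torch.arange(R, device=x.device).unsqueeze(1)  # [R, 1]
    cols = rows * k + torch.arange(k, device=x.device)    # [R, k]: i*k .. i*k+k-1
    out = x.clone()  # leave the input untouched
    # rows/cols broadcast to [R, k], so out[:, rows, cols] has shape [B, R, k]
    out[:, rows, cols] = values.to(device=out.device, dtype=out.dtype)
    return out


x = torch.arange(18).view(1, 3, 6)
print(replace_diag_blocks(x, torch.tensor([1, 1])))
```

Because the indices only depend on R and k, they could be computed once and reused across calls if the shape is fixed.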

Do I have to use a loop for this, or am I missing some efficient PyTorch API?
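A loop should not be necessary. Another sketch, assuming `x` is contiguous (so that `x.view` succeeds) and that it is acceptable to modify `x` in place: view the tensor as [B, R, R, k] and take `torch.diagonal` over the two block dimensions. `torch.diagonal` returns a view, so copying into it writes straight into `x`. `replace_diag_blocks_` is a hypothetical helper name.

```python
import torch


def replace_diag_blocks_(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """Hypothetical in-place helper: overwrite the 1 x k diagonal blocks of
    x (shape [B, R, R*k]) with `values` (shape [k]). Returns x."""
    B, R, C = x.shape
    k = C // R
    assert R * k == C, "assumes the last dimension is R * k"
    # [B, R, R, k] view; raises if x's memory layout does not allow it
    blocks = x.view(B, R, R, k)
    # Diagonal over dims 1 and 2 is a view of shape [B, k, R] with
    # diag[b, c, i] aliasing x[b, i, i*k + c]
    diag = torch.diagonal(blocks, dim1=1, dim2=2)
    # values [k] -> [k, 1] broadcasts over the trailing R dimension;
    # copy_ also converts dtype/device if they differ
    diag.copy_(values.reshape(k, 1))
    return x


x = torch.arange(18).view(1, 3, 6)
replace_diag_blocks_(x, torch.tensor([1, 1]))
print(x)
```

Calling `.contiguous()` on a non-contiguous input first would make a copy, so the write would then land in that copy rather than the original. The in-place write also fails on a leaf tensor that requires grad; the out-of-place indexing approach avoids both issues.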

Thank you very much!