What is the batched version of torch.diagflat?

So, I am trying to modify some of my code so that it supports the batch dimension. One of the lines uses torch.diagflat, and I was wondering what would be the batched version of it?

I see there are the torch.diag and torch.diagonal functions, but it's not clear to me whether either of them replicates torch.diagflat.

import torch
x = torch.randn(2, 3) # batch of 2
print(torch.diagflat(x).shape)
# size is torch.Size([6, 6]) instead of torch.Size([2, 3, 3])

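You can use torch.diag_embed with x.view(batch_size, -1) as the input. The view flattens each sample, as torch.diagflat does, and diag_embed then places each flattened row on the diagonal of its own square matrix. If your tensor is not necessarily contiguous, use x.reshape(batch_size, -1) instead of x.view.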
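A minimal sketch of that approach, wrapped in a helper. The name `batched_diagflat` is hypothetical (it is not a PyTorch function), and it assumes dim 0 is the batch dimension. It uses `reshape` so that non-contiguous inputs also work, and checks the result against a per-sample loop over `torch.diagflat`.

```python
import torch

def batched_diagflat(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: apply torch.diagflat to each sample along dim 0.

    Each sample x[i] is flattened (as diagflat does) and placed on the
    diagonal of a square matrix, giving shape (batch, n, n), where n is the
    number of elements per sample.
    """
    batch_size = x.shape[0]
    # reshape also handles non-contiguous tensors (it copies only if needed)
    flat = x.reshape(batch_size, -1)
    # diag_embed puts the last dim on the diagonal of the last two dims
    return torch.diag_embed(flat)

x = torch.randn(2, 3)  # batch of 2
out = batched_diagflat(x)
print(out.shape)  # torch.Size([2, 3, 3])

# Compare against a per-sample loop using torch.diagflat
expected = torch.stack([torch.diagflat(xi) for xi in x])
print(torch.equal(out, expected))  # True

# Samples with more than one dim are flattened first, like diagflat does.
# The transpose makes y non-contiguous, so y.view(4, -1) would fail here.
y = torch.randn(4, 2, 3).transpose(1, 2)
print(batched_diagflat(y).shape)  # torch.Size([4, 6, 6])
```

For a plain (batch, n) input like the one in the question, the reshape is a no-op and `torch.diag_embed(x)` alone already gives the (batch, n, n) result.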

Diagonals seem to be one of the bits of the NumPy API that weren't thought out terribly well with respect to being flexible and intuitive for use cases like batching (the default dimension behaviour of numpy.diagonal always seemed odd to me, too)…
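To illustrate the dimension defaults mentioned above, here is a short sketch. torch.diagonal defaults to dim1=0 and dim2=1 (mirroring numpy.diagonal's axis1=0, axis2=1), so on a batch of matrices it takes the diagonal across the batch dimension. Passing the last two dims explicitly gives one diagonal per sample, and torch.diag_embed, which defaults to the last two dims, goes the other way.

```python
import torch

m = torch.randn(2, 3, 3)  # batch of 2 matrices, each 3x3

# Default dims (dim1=0, dim2=1) pair the batch dim with the first matrix dim,
# which is rarely what batched code wants.
print(torch.diagonal(m).shape)  # torch.Size([3, 2])

# Taking the diagonal over the last two dims gives one diagonal per sample.
d = torch.diagonal(m, dim1=-2, dim2=-1)
print(d.shape)  # torch.Size([2, 3])

# diag_embed uses the last two dims by default, so it rebuilds a batch of
# diagonal matrices from the per-sample diagonals.
rebuilt = torch.diag_embed(d)
print(rebuilt.shape)  # torch.Size([2, 3, 3])
print(torch.equal(torch.diagonal(rebuilt, dim1=-2, dim2=-1), d))  # True
```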

Best regards

Thomas

Thank you for the help!