Create diagonal matrices from batch

I have a batch

x = torch.rand(size=(M, N))

and want to create, for each of the M inputs, an N x N diagonal matrix, so that the output has dimensions M x N x N. How can I do that? If I pass x to torch.diag, I get a one-dimensional output.

Any idea what I am doing wrong?
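The one-dimensional output most likely comes from torch.diag's rank-dependent behaviour. Given a 2-D tensor, it extracts the main diagonal. Given a 1-D tensor, it builds a square diagonal matrix. The sketch below shows both cases. The sizes M = 3, N = 4 are arbitrary values chosen only for illustration.

```python
import torch

M, N = 3, 4  # illustrative sizes only
x = torch.rand(size=(M, N))

# 2-D input: torch.diag *extracts* the main diagonal (min(M, N) elements).
d = torch.diag(x)
print(d.shape)  # torch.Size([3])

# 1-D input: torch.diag *builds* an N x N diagonal matrix.
single = torch.diag(x[0])
print(single.shape)  # torch.Size([4, 4])
```

A Python loop that calls torch.diag once per row and stacks the results would give M x N x N. However, it does not operate on the whole batch at once.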

Hi Samuel!

Try:

torch.diag_embed (torch.rand (size = (M, N)))
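A minimal sketch to check the result, again assuming illustrative sizes M = 3, N = 4. torch.diag_embed places the last dimension of its input on the diagonal of new N x N matrices. The example compares it against a per-row torch.diag loop. It also compares it against an equivalent broadcasting formulation, which multiplies each row by torch.eye(N) (shown only as an alternative, not as what diag_embed does internally).

```python
import torch

M, N = 3, 4  # illustrative sizes only
x = torch.rand(size=(M, N))

# Batched: each row of x becomes the diagonal of an N x N matrix.
out = torch.diag_embed(x)
print(out.shape)  # torch.Size([3, 4, 4])

# Cross-check: build each diagonal matrix separately and stack them.
loop = torch.stack([torch.diag(row) for row in x])
assert torch.equal(out, loop)

# Alternative via broadcasting: (M, N, 1) * (N, N) -> (M, N, N),
# so out[m, i, j] = x[m, i] * (1 if i == j else 0).
bcast = x.unsqueeze(-1) * torch.eye(N)
assert torch.equal(out, bcast)
```

diag_embed is the direct tool here. It also accepts offset, dim1 and dim2 arguments if the diagonal needs to go somewhere other than the last two dimensions.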

Best.

K. Frank
