torch.diag for batches?

How can I use torch.diag on batched inputs, i.e. on a batch of vectors or a batch of matrices?

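I’d use torch.diagonal and torch.diag_embed instead. torch.diag_embed turns a batch of vectors into a batch of diagonal matrices (it uses the last two dimensions by default). torch.diagonal extracts diagonals, but it defaults to dim1=0, dim2=1, so for a batch of matrices pass dim1=-2, dim2=-1.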
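A minimal sketch of both directions, assuming the batch dimension comes first (vectors shaped [batch, n], matrices shaped [batch, n, n]); the sizes below are only illustrative. The last check compares against looping torch.diag over the batch:

```python
import torch

# A batch of 4 vectors of length 3, shape [4, 3] (illustrative sizes).
vecs = torch.arange(12.0).reshape(4, 3)

# Vector -> diagonal matrix, batched: diag_embed puts the last dimension
# on the diagonal of the last two output dimensions, shape [4, 3, 3].
mats = torch.diag_embed(vecs)

# Matrix -> diagonal, batched: diagonal defaults to dim1=0, dim2=1,
# so select the last two dims to get one diagonal per matrix, shape [4, 3].
diags = torch.diagonal(mats, dim1=-2, dim2=-1)

# Same result as applying torch.diag to each element in a Python loop.
looped = torch.stack([torch.diag(v) for v in vecs])
print(torch.equal(mats, looped), torch.equal(diags, vecs))
```

Both calls work on any number of leading batch dimensions, which avoids the Python loop entirely.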

Best regards

Thomas