How to use torch.diag for batches?
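torch.diag only handles a single 1-D vector or 2-D matrix, so for batches I’d use torch.diag_embed (to build a batch of diagonal matrices from a batch of vectors) and torch.diagonal (to extract the diagonals from a batch of matrices) instead.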
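A minimal sketch of both directions, assuming the batch dimension comes first and the matrix dimensions are the last two. torch.diag_embed uses the last two dims by default, but torch.diagonal defaults to dims 0 and 1, so you need to pass dim1=-2, dim2=-1 to keep the batch dimension intact. The tensor sizes are only illustrative:

```python
import torch

# A batch of 3 vectors, each of length 4 (illustrative sizes).
v = torch.arange(12.0).reshape(3, 4)

# Vectors -> diagonal matrices: shape (3, 4, 4).
# Each D[b] is the same as torch.diag(v[b]).
D = torch.diag_embed(v)

# Matrices -> diagonals: shape (3, 4).
# Explicit dim1/dim2 select the last two dims; the default (0, 1)
# would wrongly take a "diagonal" across the batch dimension.
d = torch.diagonal(D, dim1=-2, dim2=-1)

print(D.shape)  # torch.Size([3, 4, 4])
print(d.shape)  # torch.Size([3, 4])
```

Both functions also accept an offset argument to work with off-diagonals, just like torch.diag’s diagonal argument.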
Best regards
Thomas