I want to create a block-diagonal mask matrix with torch.block_diag(), where each diagonal block can have a different shape. For example:
torch.block_diag(torch.ones((2,2)), torch.ones((3,3)), torch.ones((2,2)))
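To make the general case concrete, here is a minimal sketch of the same call driven by a list of block sizes. The name `sizes` and the assumption that every block is a square of ones are only for illustration.

```python
import torch

# Hypothetical input: side length of each square all-ones block.
sizes = [2, 3, 2]

# Build one all-ones block per size and unpack the list into block_diag.
blocks = [torch.ones(n, n) for n in sizes]
mask = torch.block_diag(*blocks)

print(mask.shape)  # torch.Size([7, 7])
```

This builds each block separately in a Python loop, which is the part that may get slow when there are many blocks.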
Is there an efficient way to achieve this?