# A batch of two 5x5 matrices; the goal is the diagonal of each one,
# i.e. a result of shape (2, 5).
a = torch.randn(2, 5, 5)

# Reference approach: take each matrix's diagonal in a Python loop
# (iterating a tensor walks its first dimension), then stack the rows.
result = torch.stack([torch.diag(matrix) for matrix in a])

# Vectorized approach. torch.diag only accepts 1-D or 2-D input and raises
# on a 3-D tensor, so use torch.diagonal over the last two dims instead.
# This yields shape (2, 5), with result[i, j] == a[i, j, j].
# Note: torch.diagonal returns a view of `a`; call .clone() if an
# independent copy is needed (the stacked loop result above is a copy).
result = torch.diagonal(a, dim1=-2, dim2=-1)
# A batch of two 5x5 matrices; the goal is the diagonal of each one,
# i.e. a result of shape (2, 5).
a = torch.randn(2, 5, 5)

# Reference approach: take each matrix's diagonal in a Python loop
# (iterating a tensor walks its first dimension), then stack the rows.
result = torch.stack([torch.diag(matrix) for matrix in a])

# Vectorized approach. torch.diag only accepts 1-D or 2-D input and raises
# on a 3-D tensor, so use torch.diagonal over the last two dims instead.
# This yields shape (2, 5), with result[i, j] == a[i, j, j].
# Note: torch.diagonal returns a view of `a`; call .clone() if an
# independent copy is needed (the stacked loop result above is a copy).
result = torch.diagonal(a, dim1=-2, dim2=-1)