torch.diag does not support batched input (it accepts only 1-D or 2-D tensors)

import torch

a = torch.randn(2, 5, 5)

# torch.diag only accepts 1-D or 2-D tensors and raises a RuntimeError when
# given a 3-D batch like `a`. Looping over a.size(0) and calling torch.stack
# works around that, but it runs in Python and torch.stack fails on an empty
# batch.
#
# torch.diagonal is the batched equivalent: with dim1=-2, dim2=-1 it takes the
# main diagonal of the last two dimensions for every leading (batch) index in
# a single call. That also covers any number of leading batch dimensions.
#
# torch.diagonal returns a view that shares storage with `a`. The .clone()
# makes `result` an independent tensor, like the stacked per-matrix
# torch.diag results.
result = torch.diagonal(a, dim1=-2, dim2=-1).clone()  # size of result is `2*5`