# Get the trace for a batch of matrices

Hi All,

I was wondering if it’s at all possible to take the trace of each matrix in a batch of matrices. For example, let’s say I have a tensor of shape `[B, N, N]` and want the trace of each of the `B` matrices of shape `[N, N]`.

If I use the `torch.trace` command, I get an error saying that it expected a matrix. I understand I could just iterate over dim=0, take the trace of each matrix, and then collect the results in a vector. But is there a way to do this with one command that takes `[B, N, N]` and returns `[B]`?
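For reference, the loop described above might look like this. It is a minimal sketch; `batch_trace_loop` is a hypothetical helper name, not a PyTorch API. Iterating over a tensor yields its slices along dim 0, so each `m` is one `[N, N]` matrix.

```python
import torch

def batch_trace_loop(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: trace of each [N, N] slice along dim 0,
    # stacked into a vector of shape [B]
    return torch.stack([torch.trace(m) for m in x])

B, N = 2, 3
x = torch.randn(B, N, N)
print(batch_trace_loop(x).shape)  # torch.Size([2])
```

This works, but it launches one small `torch.trace` call per matrix, which is what a single batched operation avoids.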

Thanks in advance!

You can do this via a combination of taking a (batch) diagonal and then summing each diagonal.

So:

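```python
import torch

B, N = 2, 3
x = torch.randn(B, N, N)
# Diagonal of each [N, N] matrix gives [B, N]; summing the last dim gives [B]
batch_trace = x.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)
```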
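A closely related one-liner, offered as an alternative rather than as part of the answer above, is `torch.einsum` with a repeated index: writing `i` twice in `"bii"` selects the diagonal of each matrix, and leaving `i` out of the output sums over it. The sketch below checks that it agrees with the diagonal-then-sum version.

```python
import torch

B, N = 2, 3
x = torch.randn(B, N, N)

# Diagonal-then-sum: [B, N, N] -> [B, N] -> [B]
trace_diag = x.diagonal(offset=0, dim1=-2, dim2=-1).sum(-1)

# einsum: repeated "i" takes the diagonal, dropping it from the output sums it
trace_einsum = torch.einsum("bii->b", x)

print(torch.allclose(trace_diag, trace_einsum))  # True
```

Note that `dim1=-1, dim2=-2` and `dim1=-2, dim2=-1` select the same main diagonal of a square matrix, so either ordering gives the same trace.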

If you’re on a nightly build of PyTorch (`torch.vmap` is also available in PyTorch 2.0 and later releases), this can be accomplished in one shot via `torch.vmap`. `vmap` essentially “adds a batch dimension to your code”:

```python
import torch

B, N = 2, 3
x = torch.randn(B, N, N)
# vmap maps torch.trace over dim 0: [B, N, N] -> [B]
batch_trace = torch.vmap(torch.trace)(x)
```
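To illustrate what “adds a batch dimension” means in practice: `vmap` maps over dim 0 by default, and its `in_dims` argument tells it where the batch dimension is when it isn't first. The sketch below assumes the batch happens to be stored last, as `[N, N, B]`, which is an illustrative layout rather than anything from the question.

```python
import torch

B, N = 4, 3
x = torch.randn(B, N, N)

# Default: batch dimension is dim 0, [B, N, N] -> [B]
batch_trace = torch.vmap(torch.trace)(x)

# Same matrices stored as [N, N, B]; x_last[:, :, b] equals x[b]
x_last = x.permute(1, 2, 0)
batch_trace_last = torch.vmap(torch.trace, in_dims=2)(x_last)

print(torch.allclose(batch_trace, batch_trace_last))  # True
```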

Thank you for the extensive answer! It works as intended now! Also, may I ask a quick question about `torch.vmap`: would it work with functions that aren't batchable, like `torch.autograd.functional.hessian`? I’m currently using that function, and it’s a major bottleneck in my code’s runtime!

Thank you!
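One possible direction, sketched under the assumption that you are on PyTorch 2.0 or later: instead of `torch.autograd.functional.hessian`, use `torch.func.hessian`, which is designed to compose with `torch.func.vmap`. I would not rely on wrapping `torch.autograd.functional.hessian` itself in `vmap`, since the `torch.func` transforms are the ones intended to be combined. The function `f` below is a hypothetical scalar-valued example; for `sum(x**3)` the Hessian is `diag(6 * x)`, which makes the result easy to check.

```python
import torch
from torch.func import hessian, vmap

# Hypothetical scalar-valued function of a single [N] input
def f(x: torch.Tensor) -> torch.Tensor:
    return (x ** 3).sum()

B, N = 4, 3
xs = torch.randn(B, N)

# hessian(f) maps one [N] input to an [N, N] Hessian;
# vmap batches that over dim 0 of xs, giving [B, N, N]
batched_hessian = vmap(hessian(f))(xs)
print(batched_hessian.shape)  # torch.Size([4, 3, 3])

# For f = sum(x**3), each Hessian is diagonal with entries 6 * x
print(torch.allclose(batched_hessian, torch.diag_embed(6 * xs)))  # True
```

Whether this removes the bottleneck depends on your actual function and sizes, so it is worth timing against the loop you have now.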