Get the trace for a batch of matrices

Hi All,

I was wondering if it’s possible to take the trace of each matrix in a batch of matrices? For example, let’s say I have a Tensor of shape [B, N, N] and want the trace of each [N, N] matrix along the batch dimension B.

If I use torch.trace, I get an error saying that it expected a matrix. I understand I could just iterate over dim=0, take the trace of each matrix, and collect the results in a vector. But is there a way to do this with one command that takes [B, N, N] and returns [B]?

Thanks in advance! 🙂

You can do this via a combination of taking a (batch) diagonal and then summing each diagonal.

So:

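import torch

B, N = 2, 3
x = torch.randn(B, N, N)
# Take the main diagonal of each [N, N] matrix, then sum it: the result has shape [B]
batch_trace = x.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)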
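
As a quick sanity check, the diagonal-then-sum approach can be compared against the explicit loop over dim=0 mentioned in the question. This is only a minimal sketch; the names looped and batched are made up for illustration:

```python
import torch

torch.manual_seed(0)
B, N = 4, 3
x = torch.randn(B, N, N)

# Baseline: call torch.trace on each [N, N] matrix and stack the scalars.
looped = torch.stack([torch.trace(m) for m in x])

# Batched version: main diagonal of each matrix, summed over its last dim.
batched = x.diagonal(offset=0, dim1=-2, dim2=-1).sum(-1)

print(batched.shape)                    # one trace per matrix in the batch
print(torch.allclose(looped, batched))  # both approaches should agree
```

For the main diagonal the order of dim1 and dim2 does not matter, so dim1=-1, dim2=-2 and dim1=-2, dim2=-1 give the same result.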

If you’re on a nightly build of PyTorch (or a release that ships torch.vmap, which as far as I know means 2.0 and later), this can be accomplished in one shot via torch.vmap. vmap essentially “adds a batch dimension to your code”:

import torch

B, N = 2, 3
x = torch.randn(B, N, N)
batch_trace = torch.vmap(torch.trace)(x)  # shape [B]
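
To illustrate what “adds a batch dimension” means, here is a small sketch where the batch dimension is not the first one. It assumes a PyTorch version that provides torch.vmap with the in_dims argument; the [N, N, B] layout is just an example I picked:

```python
import torch

N, B = 3, 5
x = torch.randn(N, N, B)  # batch dimension is last here, not first

# in_dims=2 tells vmap which dimension to map over;
# torch.trace itself still only ever sees an [N, N] matrix.
traces = torch.vmap(torch.trace, in_dims=2)(x)

# Reference: diagonal over the two matrix dims, then sum each diagonal.
reference = x.diagonal(dim1=0, dim2=1).sum(-1)

print(traces.shape)
print(torch.allclose(traces, reference))
```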

Thank you for the extensive answer! It works as intended now! 🙂

Also, if I could ask a quick question about torch.vmap, would such a function work with non-batchable functions like torch.autograd.functional.hessian? I’m currently using that function, and it’s a major bottleneck in the runtime of my code!

Thank you!
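
One possible way to batch a Hessian, sketched under some assumptions: it uses torch.func.hessian (available in PyTorch 2.0 and later, as far as I know) rather than torch.autograd.functional.hessian, and the function f below is a made-up stand-in for whatever scalar-valued function you actually have:

```python
import torch
from torch.func import hessian, vmap

def f(x):
    # Hypothetical scalar function of a single [N] input.
    return (x ** 3).sum()

xs = torch.randn(4, 3)  # batch of 4 inputs, each of size 3

# hessian(f) maps one [N] input to an [N, N] Hessian;
# vmap maps that over the leading batch dimension.
H = vmap(hessian(f))(xs)

# For this particular f the Hessian is diag(6 * x), which gives a simple check.
expected = torch.diag_embed(6 * xs)

print(H.shape)
print(torch.allclose(H, expected, atol=1e-5))
```

Whether this helps with your bottleneck depends on whether your function can be written so that the torch.func transforms can trace through it.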

Extending an example from the docs, you can also use torch.einsum:

import torch

b, n = 2, 3
batch_trace = torch.einsum("...ii", torch.randn(b, n, n))  # shape [b]
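
If the implicit output in "...ii" feels a bit cryptic, the same thing can be spelled out with an explicit output subscript. A small sketch comparing both forms against the diagonal-then-sum approach from earlier in the thread (variable names are just for illustration):

```python
import torch

b, n = 2, 3
x = torch.randn(b, n, n)

# Implicit form: the repeated index i is summed, the "..." dims are kept.
implicit = torch.einsum("...ii", x)

# Explicit form: name the batch index and say it is the only output dim.
explicit = torch.einsum("bii->b", x)

# Reference: diagonal of each matrix, then sum.
reference = x.diagonal(dim1=-2, dim2=-1).sum(-1)

print(torch.allclose(implicit, reference))
print(torch.allclose(explicit, reference))
```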