Hi All,
I was wondering if it’s at all possible to take the trace of a matrix for a batch of matrices. For example, let’s say I have some Tensor of shape [B, N, N] and wish to find the trace of each [N, N] matrix along B?
If I use the torch.trace command, I get an error saying that it expected a matrix. I understand I could just iterate over dim=0, take the trace of each matrix, and collect the results into a vector. But is there a way to do this with one command that takes [B, N, N] and returns [B]?
Thanks in advance!
You can do this via a combination of taking a (batch) diagonal and then summing each diagonal.
So:
import torch
B, N = 2, 3
x = torch.randn(B, N, N)
# Take the main diagonal of each N x N matrix (shape [B, N]), then sum it (shape [B])
x.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)
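To check that the diagonal-then-sum trick matches the loop described in the question, here is a small sketch comparing it against one torch.trace call per matrix. The sizes are arbitrary illustration values.

```python
import torch

B, N = 4, 5
x = torch.randn(B, N, N)

# Batched trace: main diagonal of each matrix (shape [B, N]), summed over the last dim (shape [B])
batched = x.diagonal(offset=0, dim1=-2, dim2=-1).sum(-1)

# Reference: the per-matrix loop from the question, one torch.trace call per [N, N] matrix
looped = torch.stack([torch.trace(m) for m in x])

print(batched.shape)  # torch.Size([4])
print(torch.allclose(batched, looped))
```

Because diagonal and sum both operate on the trailing dimensions, the same one-liner also works unchanged for tensors with more than one leading batch dimension, such as [A, B, N, N].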
If you’re on a nightly build of PyTorch, this can be accomplished in one shot via torch.vmap (it has since shipped in stable releases and is available in PyTorch 2.0 and later). vmap essentially “adds a batch dimension to your code”:
import torch
B, N = 2, 3
x = torch.randn(B, N, N)
# vmap applies torch.trace to each [N, N] matrix along dim 0, giving shape [B]
batch_trace = torch.vmap(torch.trace)(x)
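A slightly fuller sketch, assuming PyTorch 2.0 or later where torch.vmap is available in stable builds. It checks the vmap result against the diagonal-sum approach and shows that nesting vmap handles two leading batch dimensions. The extra dimension A is an illustrative value, not something from the original question.

```python
import torch

B, N = 4, 5
x = torch.randn(B, N, N)

# vmap maps torch.trace over dim 0 of x, so each call sees a single [N, N] matrix
batch_trace = torch.vmap(torch.trace)(x)

# Same result as taking the diagonal and summing it
reference = x.diagonal(dim1=-2, dim2=-1).sum(-1)
print(batch_trace.shape)  # torch.Size([4])

# Nesting vmap maps over two leading batch dimensions, e.g. [A, B, N, N]
A = 3
y = torch.randn(A, B, N, N)
nested = torch.vmap(torch.vmap(torch.trace))(y)
print(nested.shape)  # torch.Size([3, 4])
```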
Thank you for the extensive answer! It works as intended now!
Also, if I could ask a quick question about torch.vmap: would it work with functions that aren’t batched, like torch.autograd.functional.hessian? I’m currently using that function, and it’s a major bottleneck in my code’s runtime!
Thank you!
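The thread doesn’t answer this directly, but as a sketch assuming PyTorch 2.0 or later: the torch.func module provides its own hessian transform, and the torch.func transforms are designed to compose with torch.vmap. The function f below is hypothetical and only for illustration; its Hessian is diagonal with entries 6 * x_i, which makes the result easy to check. No claim is made here about speed relative to torch.autograd.functional.hessian.

```python
import torch
from torch.func import hessian

# Hypothetical scalar function used only for illustration: f(x) = sum(x_i ** 3)
def f(x):
    return (x ** 3).sum()

B, N = 4, 3
xs = torch.randn(B, N)

# hessian(f) computes the [N, N] Hessian for one input; vmap batches it over dim 0
batched_hessians = torch.vmap(hessian(f))(xs)
print(batched_hessians.shape)  # torch.Size([4, 3, 3])

# For this f the Hessian is diag(6 * x), which gives a simple correctness check
expected = torch.diag_embed(6 * xs)
```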
Extending an example from the docs, you can also use torch.einsum:
import torch
b, n = 2, 3
# "...ii" sums the diagonal of each trailing [n, n] matrix, giving shape [b]
torch.einsum("...ii", torch.randn(b, n, n))
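The same einsum can be written with an explicit output, which makes the reduction easier to read: the repeated index i picks out the diagonal, and because i is missing from the output it is summed over, while "..." carries the batch dimensions through. A small sketch with illustrative sizes:

```python
import torch

b, n = 4, 5
x = torch.randn(b, n, n)

# "...ii->..." : repeated i selects the diagonal; i is absent from the output, so it is summed
traces = torch.einsum("...ii->...", x)
print(traces.shape)  # torch.Size([4])

# The same subscript string handles several leading batch dimensions
y = torch.randn(3, b, n, n)
print(torch.einsum("...ii->...", y).shape)  # torch.Size([3, 4])
```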