Is there a way to compute the matrix trace in a batched (broadcast) fashion?

For example, I have a tensor of size (batch_size, hidden_size, hidden_size). Is there a way to compute the trace of each (hidden_size, hidden_size) matrix for every sample in this batch without a loop, so that the output tensor is a vector of size (batch_size)?
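For reference, the loop-based version the question wants to avoid might look like the sketch below. `batched_trace_loop` is a hypothetical helper name and the sizes are arbitrary; it simply applies `torch.trace` to each matrix in the batch and stacks the results.

```python
import torch

def batched_trace_loop(a: torch.Tensor) -> torch.Tensor:
    # Iterating over a tensor yields slices along dim 0, i.e. one
    # (hidden_size, hidden_size) matrix per sample; torch.trace needs a 2-D input.
    return torch.stack([torch.trace(m) for m in a])

batch_size, hidden_size = 4, 3
a = torch.randn(batch_size, hidden_size, hidden_size)
out = batched_trace_loop(a)
print(out.shape)  # torch.Size([4])
```

This works, but it launches one small op per sample, which is what a vectorized alternative avoids.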

Any tips will help! Thanks!


Did you find a solution (apart from looping)? I was looking for a way to do this operation, but it seems there is no batched trace function.

There are several easy ways to do this, for example:

b = torch.einsum('bii->b', a)

or

b2 = torch.diagonal(a, dim1=-2, dim2=-1).sum(-1) 
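A small self-contained check that both one-liners agree with a per-matrix `torch.trace` loop might look like this. The shapes are arbitrary and `float64` is used only so the comparison is not sensitive to floating-point summation order.

```python
import torch

a = torch.randn(8, 5, 5, dtype=torch.float64)  # (batch_size, hidden_size, hidden_size)

# einsum: the repeated index i picks the diagonal, and dropping i from
# the output sums over it, leaving one value per batch element b
b = torch.einsum('bii->b', a)

# diagonal over the last two dims has shape (8, 5); summing the last dim gives (8,)
b2 = torch.diagonal(a, dim1=-2, dim2=-1).sum(-1)

# loop-based reference
ref = torch.stack([torch.trace(m) for m in a])

print(b.shape, b2.shape)                                # torch.Size([8]) torch.Size([8])
print(torch.allclose(b, ref), torch.allclose(b2, ref))  # True True

# Both forms also handle any number of leading batch dimensions
x = torch.randn(2, 3, 4, 4, dtype=torch.float64)
t1 = torch.einsum('...ii->...', x)
t2 = torch.diagonal(x, dim1=-2, dim2=-1).sum(-1)
print(t1.shape, torch.allclose(t1, t2))                 # torch.Size([2, 3]) True
```

Both are ordinary differentiable tensor ops, so gradients flow through them as usual.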

Best regards

Thomas
