How to calculate matrix trace in 3d tensor

How can I calculate the trace of each matrix in a 3D tensor whose first dimension is the batch size?
torch.trace only works on 2D input, but I want to run it on my batched tensor and get an output of shape (batch_size, 1).


You can define a mask (a matrix containing only 0s and 1s) with 1s on the diagonal, multiply it elementwise with your matrices, and then sum over the last two dimensions. Example:

import torch

input_matrix = torch.ones((4, 5, 5))  # batch_size is 4
mask = torch.zeros((4, 5, 5))
idx = torch.arange(5)
mask[:, idx, idx] = 1.0  # set the diagonal of every matrix to 1; off-diagonal entries stay 0

output = input_matrix * mask  # zeroes out all off-diagonal values
output = torch.sum(output, dim=(1, 2))  # shape (batch,); holds the trace of each 5 x 5 matrix
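As a possible simplification of the mask approach, here is a minimal sketch: torch.eye already produces a 0/1 diagonal mask, and broadcasting applies it to every matrix in the batch, so a batch-sized mask is not strictly needed. The random input and the names eye_mask, batch_trace, and expected are illustrative assumptions, not part of the answer above.

```python
import torch

torch.manual_seed(0)
input_matrix = torch.randn(4, 5, 5)  # illustrative batch of 4 matrices
eye_mask = torch.eye(5)              # (5, 5) identity, broadcast across the batch

# Zero out the off-diagonal entries, then sum over the two matrix dimensions.
batch_trace = (input_matrix * eye_mask).sum(dim=(1, 2))  # shape (4,)

# Cross-check against torch.trace applied to each matrix separately.
expected = torch.stack([torch.trace(m) for m in input_matrix])
```

Note that this still multiplies and sums every element of each matrix, so it does more work than reading the diagonal directly.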


If you have a 3D tensor where the first dimension is the batch dimension, the following will work:

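import torch

B = 1000  # batch size
A = 16    # number of rows/columns
matrices = torch.randn(B, A, A)  # batch of square matrices
# Take the diagonal over the last two dims, then sum it: result has shape (B,).
# Pass keepdim=True to sum() to get shape (B, 1) instead.
trace = matrices.diagonal(offset=0, dim1=-2, dim2=-1).sum(dim=-1)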
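To get the (batch_size, 1) shape asked for in the question, and to sanity-check the result, something like the following sketch may help. It compares the diagonal-and-sum approach with an einsum formulation (a swapped-in alternative, not from the answer above) and with a plain loop over torch.trace. The small batch size and the variable names are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)
matrices = torch.randn(8, 3, 3)  # illustrative batch of 8 matrices, each 3 x 3

# Diagonal over the last two dims, summed; keepdim=True keeps a trailing axis.
trace_diag = matrices.diagonal(dim1=-2, dim2=-1).sum(dim=-1, keepdim=True)  # (8, 1)

# Alternative: the repeated index "ii" selects the diagonal, which is summed per batch item.
trace_einsum = torch.einsum("bii->b", matrices).unsqueeze(-1)  # (8, 1)

# Reference: torch.trace on each 2D matrix, stacked back into a batch.
expected = torch.stack([torch.trace(m) for m in matrices]).unsqueeze(-1)  # (8, 1)
```

All three should agree up to floating-point rounding; the loop is only there as a reference and will be slower for large batches.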