Typically m << n. As an example, take n=3, m=2: I want to reduce a 3x2 tensor to a 3-vector as follows.

First, lay out the 3x2 tensor like this, with row i shifted right by i positions and wrapping around:

00 01 xx

xx 10 11

21 xx 20

Here I’m only showing the indices, and xx marks elements that don’t exist. Then take a column sum to obtain a 3-vector:

[00 + 21, 01 + 10, 11 + 20]
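One way to read the wrap-around version is that element y[i, j] lands in output slot (i + j) mod n. A minimal sketch under that reading, using `index_add_` so no n x n intermediate is built (the function name `shifted_column_sum` and the sample values are only for illustration):

```python
import torch

def shifted_column_sum(y: torch.Tensor) -> torch.Tensor:
    """For y of shape (n, m), add y[i, j] into output slot (i + j) % n."""
    n, m = y.shape
    rows = torch.arange(n, device=y.device).unsqueeze(1)  # shape (n, 1)
    cols = torch.arange(m, device=y.device).unsqueeze(0)  # shape (1, m)
    target = (rows + cols) % n                            # shape (n, m), wrapped column index
    out = torch.zeros(n, dtype=y.dtype, device=y.device)
    # Scatter-add every element of y into its target column.
    out.index_add_(0, target.reshape(-1), y.reshape(-1))
    return out

# Sample values: y00=1, y01=2, y10=3, y11=4, y20=5, y21=6
y = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
result = shifted_column_sum(y)  # [y00 + y21, y01 + y10, y11 + y20]
```

This does O(n*m) work, which should matter when m << n.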

Or without the wrap-around, write it as

00 01 xx xx

xx 10 11 xx

xx xx 20 21

and again sum the columns to obtain

[00, 01 + 10, 11 + 20, 21][:-1]
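The version without wrap-around can be sketched the same way: y[i, j] goes to slot i + j of a length n + m - 1 buffer, which is then truncated. The sketch below assumes the result should keep the first n entries, which coincides with `[:-1]` when m=2; the function name `shifted_column_sum_nowrap` is hypothetical.

```python
import torch

def shifted_column_sum_nowrap(y: torch.Tensor) -> torch.Tensor:
    """For y of shape (n, m), add y[i, j] into slot i + j, then keep the first n slots."""
    n, m = y.shape
    rows = torch.arange(n, device=y.device).unsqueeze(1)  # shape (n, 1)
    cols = torch.arange(m, device=y.device).unsqueeze(0)  # shape (1, m)
    target = rows + cols                                  # shape (n, m), no modulo
    full = torch.zeros(n + m - 1, dtype=y.dtype, device=y.device)
    full.index_add_(0, target.reshape(-1), y.reshape(-1))
    # Assumption: truncate to length n (equals [:-1] for m = 2).
    return full[:n]

# Sample values: y00=1, y01=2, y10=3, y11=4, y20=5, y21=6
y = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
result_nowrap = shifted_column_sum_nowrap(y)  # [y00, y01 + y10, y11 + y20]
```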

This seems almost like a conv1d-like operation, or something that could be built from torch.diag/torch.diagonal. Maybe something like:

torch.stack( [ torch.roll( torch.diag( y[ :, i ] ), shifts=i, dims=1 ) for i in range( y.shape[ 1 ] ) ] ).sum( dim=( 0, 1 ) ) ?

Thanks

PS: When I used spaces instead of xx in the original post, the spaces were eaten up, so the displayed version did not match what I typed.