Typically m << n. As a small example, take n=3 and m=2: I want to reduce a 3x2 tensor y to a 3-vector. First, I write the 3x2 tensor out like this:
00 01 xx
xx 10 11
21 xx 20
where each entry shows only the index ij of y[i, j] and xx marks a slot with no element (so row i is row i of y shifted right by i places, wrapping around). Then I take the column sums to obtain a 3-vector:
[00 + 21, 01 + 10, 11 + 20]
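A minimal sketch of the wrap-around reduction, assuming y is a 2-D tensor of shape (n, m). In the layout above, y[i, j] lands in column (i + j) % n, so a plain loop and a vectorized scatter-add with Tensor.index_add_ should agree. The function names are hypothetical, chosen for illustration.

```python
import torch

def wrap_sum_loop(y: torch.Tensor) -> torch.Tensor:
    """Reference: add y[i, j] into slot (i + j) % n, one element at a time."""
    n, m = y.shape
    out = torch.zeros(n, dtype=y.dtype, device=y.device)
    for i in range(n):
        for j in range(m):
            out[(i + j) % n] += y[i, j]
    return out

def wrap_sum(y: torch.Tensor) -> torch.Tensor:
    """Vectorized: compute every target slot at once, then scatter-add."""
    n, m = y.shape
    rows = torch.arange(n, device=y.device)[:, None]  # shape (n, 1)
    cols = torch.arange(m, device=y.device)[None, :]  # shape (1, m)
    idx = ((rows + cols) % n).reshape(-1)             # target slot for each y[i, j]
    out = torch.zeros(n, dtype=y.dtype, device=y.device)
    return out.index_add_(0, idx, y.reshape(-1))

# y[i, j] = 10*i + j, so each value reads like the index labels above
y = torch.tensor([[0., 1.], [10., 11.], [20., 21.]])
print(wrap_sum_loop(y).tolist())  # [21.0, 11.0, 31.0]
print(wrap_sum(y).tolist())       # [21.0, 11.0, 31.0]
```

The vectorized version never builds the n x n layout, so its memory stays O(n·m), which matters in the typical m << n case.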
Or without the wrap-around, write it as
00 01 xx xx
xx 10 11 xx
xx xx 20 21
and again sum the columns to obtain
[00, 01 + 10, 11 + 20, 21][:-1], i.e. dropping the last entry to get back to a 3-vector
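The non-wrapping version is the same scatter-add with target column i + j instead of (i + j) % n, into a buffer of length n + m - 1. A minimal sketch, assuming y has shape (n, m); the function name is hypothetical. Slicing with [:n] keeps the first n entries, which is what [:-1] does here because m = 2.

```python
import torch

def nowrap_sum(y: torch.Tensor) -> torch.Tensor:
    """Add y[i, j] into slot i + j; returns all n + m - 1 column sums."""
    n, m = y.shape
    idx = (torch.arange(n, device=y.device)[:, None]
           + torch.arange(m, device=y.device)[None, :]).reshape(-1)
    out = torch.zeros(n + m - 1, dtype=y.dtype, device=y.device)
    return out.index_add_(0, idx, y.reshape(-1))

y = torch.tensor([[0., 1.], [10., 11.], [20., 21.]])  # y[i, j] = 10*i + j
full = nowrap_sum(y)
print(full.tolist())                # [0.0, 11.0, 31.0, 21.0]
print(full[: y.shape[0]].tolist())  # [0.0, 11.0, 31.0]
```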
This seems almost like a conv1d-like operation, or something that torch.diag / torch.diagonal could do.
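On the torch.diagonal idea: the non-wrapped column sums are exactly the anti-diagonal sums of y (all y[i, j] with i + j = k). Flipping the columns turns anti-diagonals into ordinary diagonals, so torch.diagonal with offset m - 1 - k picks each one out. Folding that length n + m - 1 result modulo n then recovers the wrap-around version. A sketch assuming y has shape (n, m); both function names are hypothetical.

```python
import torch

def nowrap_sum_diag(y: torch.Tensor) -> torch.Tensor:
    """Sum each anti-diagonal i + j = k of y, for k = 0 .. n + m - 2."""
    n, m = y.shape
    # After flipping columns, y[i, j] sits at (i, m - 1 - j), so the
    # anti-diagonal i + j = k becomes the diagonal with offset m - 1 - k.
    yf = torch.flip(y, dims=[1])
    return torch.stack([torch.diagonal(yf, offset=m - 1 - k).sum()
                        for k in range(n + m - 1)])

def fold_mod_n(full: torch.Tensor, n: int) -> torch.Tensor:
    """Add entry k + t*n into slot k, i.e. apply the wrap-around afterwards."""
    pad = (-full.numel()) % n  # zeros needed to reach a multiple of n
    padded = torch.cat([full, full.new_zeros(pad)])
    return padded.reshape(-1, n).sum(dim=0)

y = torch.tensor([[0., 1.], [10., 11.], [20., 21.]])  # y[i, j] = 10*i + j
full = nowrap_sum_diag(y)
print(full.tolist())                          # [0.0, 11.0, 31.0, 21.0]
print(fold_mod_n(full, y.shape[0]).tolist())  # [21.0, 11.0, 31.0]
```

This loops in Python over the n + m - 1 diagonals, so it illustrates the diagonal view rather than being the fastest option for large n.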
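Would something like this work? Stacking the rolled diagonal matrices and summing over the stack gives the n x n layout above (zeros for xx); a second sum over dim 0 then gives the column sums:
torch.stack([torch.roll(torch.diag(y[:, i]), shifts=i, dims=1) for i in range(y.shape[1])]).sum(dim=0).sum(dim=0)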
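The original one-liner won't run as typed: torch.sum does not accept a Python list (the matrices need torch.stack first), and the parentheses around range(...) are unbalanced. A minimal runnable sketch of the same roll/diag idea, assuming y is a 2-D tensor of shape (n, m); the function name is hypothetical. Summing the stack gives the n x n layout with zeros in the xx slots, and one more sum over dim 0 gives the column sums.

```python
import torch

def wrap_sum_roll(y: torch.Tensor):
    """Build the wrapped n x n layout with diag + roll, then take column sums."""
    n, m = y.shape
    # torch.diag(y[:, i]) puts y[r, i] at (r, r); rolling by i along dim 1
    # moves it to (r, (r + i) % n), the slot shown in the wrap-around layout.
    layout = torch.stack([torch.roll(torch.diag(y[:, i]), shifts=i, dims=1)
                          for i in range(m)]).sum(dim=0)
    return layout, layout.sum(dim=0)

y = torch.tensor([[0., 1.], [10., 11.], [20., 21.]])  # y[i, j] = 10*i + j
layout, cols = wrap_sum_roll(y)
print(layout.tolist())  # [[0.0, 1.0, 0.0], [0.0, 10.0, 11.0], [21.0, 0.0, 20.0]]
print(cols.tolist())    # [21.0, 11.0, 31.0]
```

This materializes m matrices of size n x n, so memory grows as O(m·n²); for large n, a scatter-add directly into the target slots (i + j) % n avoids building the layout at all.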
Thanks
PS: I originally used spaces instead of xx, but the spaces were collapsed and the post didn't display the way I typed it.