ATen: how to access rows and columns of matrices

Hello! I have two matrices Q and U and want to compute the inner product of their j-th columns. The naive implementation is below.

double sum = 0.0;
for (int idx = 0; idx < m; ++idx) {
    sum += *(Q[idx][j] * U[idx][j]).data<double>();
}

I can also use the dot product function, but I can’t work out how to get the j-th column of a matrix. Thank you in advance!

I think this works:

at::dot(Q.select(1, j), U.select(1, j));

but I’m not at a machine right now where I can try it.
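
For intuition about what "the j-th column" means in memory, here is a minimal sketch in plain C++ with no ATen involved. It assumes the matrices are stored contiguously in row-major order (the default layout for a freshly created tensor, as far as I know), so that column j is the elements at offsets j, j + n, j + 2n, and so on. The name column_dot and its parameters are made up for this illustration and are not part of any library.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not an ATen function): inner product of the j-th
// columns of two m x n matrices stored contiguously in row-major order.
double column_dot(const std::vector<double>& q, const std::vector<double>& u,
                  std::size_t m, std::size_t n, std::size_t j) {
    double sum = 0.0;
    for (std::size_t i = 0; i < m; ++i) {
        // Element (i, j) sits at offset i * n + j, so walking down a column
        // means stepping through the buffer with stride n.
        sum += q[i * n + j] * u[i * n + j];
    }
    return sum;
}
```

On the ATen side, my understanding is that Q.select(1, j) returns a 1-D view of that column without copying, while Q.slice(1, j, j+1) keeps the dimension and gives an m x 1 tensor. Since at::dot expects 1-D inputs, a sliced column would likely need a squeeze(1) before being passed in.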
