Dear all,
I have a tensor of shape (B, 3N), where each row is one element of the batch and its columns hold N points in 3D space, flattened as (x1,y1,z1,x2,y2,z2,…,xN,yN,zN). I can reshape it to (B,N,3).
Now I want to apply a rotation to the dataset by projecting each 3D point onto three unit vectors (versors) e_x, e_y, e_z. Each of them is a tensor of shape (B,3), since the reference frame is different for each element in the batch.
I have implemented this with a for-loop over the batch elements and over each point (atom), in the following way:
# x.shape => (B,N,3)
# e_x.shape => (B,3) (likewise e_y and e_z)
x_new = torch.zeros((B,N,3))
for b in range(B):  # loop over batch
    for j in range(N):  # loop over atoms
        x_new[b,j,0] = torch.dot(x[b,j,:], e_x[b,:])  # x component of atom j in the new reference frame
        x_new[b,j,1] = torch.dot(x[b,j,:], e_y[b,:])  # y component
        x_new[b,j,2] = torch.dot(x[b,j,:], e_z[b,:])  # z component
This works, but of course it is really slow. Do you have any suggestions on how to vectorize this calculation and avoid the for loops? Should I combine e_x, e_y, e_z into a batch of matrices and then apply it to the data?
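To make the last question concrete, here is a minimal sketch of what combining the three versors into a batch of matrices might look like. It is only an illustration under some assumptions: the function name `rotate_batched` is made up, and x, e_x, e_y, e_z are assumed to be floating-point tensors with the shapes given above, all on the same device. The versors are stacked as the rows of a (B,3,3) tensor, which is then contracted with x along the coordinate axis:

```python
import torch


def rotate_batched(x, e_x, e_y, e_z):
    """Project every point of x onto per-batch axes (hypothetical helper).

    x:             tensor of shape (B, N, 3)
    e_x, e_y, e_z: tensors of shape (B, 3), one set of axes per batch element
    returns:       tensor of shape (B, N, 3)
    """
    # R[b, k, :] is the k-th axis (k = 0, 1, 2 for e_x, e_y, e_z) of batch element b.
    R = torch.stack((e_x, e_y, e_z), dim=1)  # shape (B, 3, 3)
    # x_new[b, j, k] = sum_i x[b, j, i] * R[b, k, i], i.e. the same dot
    # products as in the double loop, computed for all b and j at once.
    return torch.einsum("bji,bki->bjk", x, R)


if __name__ == "__main__":
    B, N = 4, 5
    x = torch.randn(B, N, 3)
    e_x, e_y, e_z = torch.randn(B, 3), torch.randn(B, 3), torch.randn(B, 3)
    x_new = rotate_batched(x, e_x, e_y, e_z)
    print(x_new.shape)
```

The same contraction could presumably also be written as a batched matrix product, `torch.bmm(x, R.transpose(1, 2))`; which of the two is faster likely depends on B, N, and the device.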
Thank you very much,
Luigi