I have two matrices, `A` and `B`.

`A` is a normal map with shape `[b, h, w, 3]`, and `B` is a transformation matrix with shape `[b, 4, 4, 3]`, where `b` is the batch size and 3 is the number of channels.
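Read literally, these shapes give `B` one 4×4 matrix per batch item and per color channel, while every pixel of `A` holds a 3-component normal. A 4×4 matrix cannot act directly on a 3-vector, so the sketch below assumes (this is not stated above) that each normal is extended to homogeneous coordinates with a trailing 1. The indices $k$ (batch), $y, x$ (pixel), $c$ (channel) and the name $E$ for the result are introduced here for illustration only:

$$
\tilde{\mathbf n}_{k,y,x} = \begin{pmatrix} A_{k,y,x,0} \\ A_{k,y,x,1} \\ A_{k,y,x,2} \\ 1 \end{pmatrix},
\qquad
E_{k,y,x,c} = \tilde{\mathbf n}_{k,y,x}^{\top}\, B_{k,:,:,c}\, \tilde{\mathbf n}_{k,y,x}
= \sum_{i=0}^{3}\sum_{j=0}^{3} \tilde n_{k,y,x,i}\, B_{k,i,j,c}\, \tilde n_{k,y,x,j}
$$

Each pixel and channel then yields a single scalar, so $E$ has shape `[b, h, w, 3]`, the same as the texture.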

I want to implement the rendering equation as `A.T * B * A`. The result will then be multiplied with the texture, which has shape `[b, h, w, 3]`, so it has to have the same shape (its values may be HDR). I tried doing these multiplications with `torch.bmm()`, but couldn't get it to work. Any ideas?
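A minimal sketch of one way to evaluate `A.T * B * A` per pixel and per channel, assuming the product is meant as the quadratic form nᵀ B_c n and that normals are lifted to homogeneous coordinates `(x, y, z, 1)` to match the 4×4 matrices. The function name `shade` is hypothetical. `torch.bmm` only accepts 3-D tensors, so these 4-D shapes would need reshaping before it could be used; `torch.einsum` contracts the indices directly instead.

```python
import torch


def shade(normals: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
    """Evaluate n^T M_c n for every pixel and every color channel.

    normals: [b, h, w, 3] normals (x, y, z)
    M:       [b, 4, 4, 3] one 4x4 matrix per batch item and color channel
    returns: [b, h, w, 3]
    """
    # Assumption: append a 1 to each normal so it matches the 4x4 matrices.
    ones = torch.ones_like(normals[..., :1])
    n = torch.cat([normals, ones], dim=-1)  # [b, h, w, 4]
    # out[b, h, w, c] = sum_{i, j} n[b, h, w, i] * M[b, i, j, c] * n[b, h, w, j]
    return torch.einsum("bhwi,bijc,bhwj->bhwc", n, M, n)


# Usage with random stand-in data (hypothetical shapes):
b, h, w = 2, 8, 8
normals = torch.nn.functional.normalize(torch.randn(b, h, w, 3), dim=-1)
M = torch.randn(b, 4, 4, 3)
texture = torch.rand(b, h, w, 3)

irradiance = shade(normals, M)   # [b, h, w, 3]
rendered = irradiance * texture  # elementwise, same shape as the texture
```

The same result could also be computed with broadcasting `torch.matmul` after permuting `M` to `[b, 3, 4, 4]`. `einsum` is used here because it keeps the index bookkeeping in a single readable string.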