Optimize for loops in torch.jit

Here is some code that is super slow (100 ms!):

Basically, it is a for loop that computes very simple dot products between vectors of 3 components.

How can I optimize this? Should I compute the backward pass myself instead?
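
For illustration only, here is a minimal sketch of the kind of loop described above, next to a batched version. It is not the original code: the function names `dots_loop` and `dots_batched`, the sizes, and the assumption that the vectors can be stacked into tensors of shape (N, 3) with independent iterations are all hypothetical.

```python
import torch

def dots_loop(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the slow version:
    # one tiny dot product per iteration over 3-component vectors.
    out = []
    for i in range(a.shape[0]):
        out.append(torch.dot(a[i], b[i]))
    return torch.stack(out)

def dots_batched(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Same result computed in one shot: elementwise product,
    # then a sum over the last (size-3) dimension.
    return (a * b).sum(dim=-1)

torch.manual_seed(0)
a = torch.randn(1000, 3, requires_grad=True)  # assumed shape (N, 3)
b = torch.randn(1000, 3)

loop_result = dots_loop(a, b)
batched_result = dots_batched(a, b)
```

If the real loop can be written this way, autograd records a couple of operations for the whole batch instead of several per iteration, which usually shrinks the cost of both the forward and the backward pass without a hand-written backward. If each iteration depends on the previous one, this rewrite does not apply directly.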