torch.jit.script causes a bad graph conversion from Torch to ONNX

Hello, I wrote this TorchScript function to populate a correlation matrix:

import torch
from typing import Tuple

@torch.jit.script
def populate_corr(corr: torch.Tensor,
                  input1: torch.Tensor,
                  input2: torch.Tensor,
                  max_displacement: Tuple[int, int],
                  stride: Tuple[int, int],
                  dilation_patch: Tuple[int, int],
                  h: int,
                  w: int):
    for i in range(0, 2 * max_displacement[0] + 1, dilation_patch[0]):
        for j in range(0, 2 * max_displacement[1] + 1, dilation_patch[1]):
            # Shifted (h, w) window of input2, then the spatial stride
            p2 = input2[:, :, i:i + h, j:j + w]
            p2 = p2[:, :, ::stride[0], ::stride[1]]
            # Reduce over channels for this displacement
            corr[:, i // dilation_patch[0], j // dilation_patch[1]] = (input1 * p2).sum(dim=1)
    return corr

When I export the Torch model to ONNX, the model size is preserved and inference gives the same results. However, exploring the ONNX model in Netron shows that the resulting graph is very convoluted and full of complex subgraphs produced by these loops. This translates into poor speed once the model is deployed on the end device (Torch → ONNX → TF → TFLite quantized).

Is there a more efficient way to write this script so that it produces a more optimized graph?
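For context, one loop-free rewrite I have been considering is sketched below (my own naming, `populate_corr_vec`; I have not yet verified how it exports to ONNX). It replaces the two Python loops with `tensor.unfold` over the displacement axes, so the whole correlation volume is built by broadcasting instead of per-displacement assignments:

```python
import torch
from typing import Tuple


def populate_corr_vec(input1: torch.Tensor,
                      input2: torch.Tensor,
                      max_displacement: Tuple[int, int],
                      stride: Tuple[int, int],
                      dilation_patch: Tuple[int, int],
                      h: int,
                      w: int) -> torch.Tensor:
    # Slide an (h, w) window over the displacement axes with tensor.unfold,
    # yielding shape [B, C, D0, D1, h, w] where D0/D1 count displacements,
    # instead of slicing once per (i, j) in a Python loop.
    p = input2.unfold(2, h, dilation_patch[0]).unfold(3, w, dilation_patch[1])
    # Apply the spatial stride inside every window.
    p = p[:, :, :, :, ::stride[0], ::stride[1]]
    # Broadcast input1 over the two displacement dims and reduce over channels,
    # giving [B, D0, D1, h', w'] -- the same layout the loop version fills.
    return (input1[:, :, None, None] * p).sum(dim=1)
```

Whether the exporter turns `Tensor.unfold` into something TFLite-friendly is exactly what I am unsure about, but at least it removes the `Loop` subgraphs from the traced computation.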