Torch.compile performance with recursive code

Hi all,

For the past two years I’ve been using JAX. I now see that PyTorch has introduced torch.compile and torch.vmap and wanted to try them. In JAX, I could wrap loops in jax.lax.scan, so that the JIT compiler compiles the loop body once instead of unrolling it, which keeps compilation times low. I tried to compile a function containing a for loop with torch.compile (no data-dependent control flow, just iterating over an input array and returning another array):

import torch

def test_func(x_seq):
    result = torch.empty_like(x_seq)
    # Note: torch.arange(0, 1, x_seq.shape[0]) would yield a single element
    # (the third argument is the step); torch.linspace gives one coefficient
    # per sequence element, which is presumably the intent.
    for i, (x, y) in enumerate(zip(x_seq, torch.linspace(0, 1, x_seq.shape[0]))):
        result[i] = x * y
    return result

The compilation time increases with the length of x_seq, which is unusable for me because my sequences can be very long. Is there a torch primitive that avoids this?