Issues with PyTorch while_loop operator when exporting a PyTorch model to ExecuTorch

Hi,
I have recently been working on a model based on a “state space representation” using PyTorch.
I was able to build the PyTorch model, and it works as expected.
Now I want to convert the model to ExecuTorch IR using torch.export.export(), but the export() call “freezes”.

I then root-caused it: the problem is a “for loop” in the model's forward() whose number of iterations depends on the input shape (in my case T = 16000, so it iterates 16,000 times). Because export() unrolls the loop, every iteration is traced into the graph, and the export call freezes due to memory issues.

To fix that, I looked into the PyTorch while_loop operator, which is meant to keep such loops in the graph instead of unrolling them.
So I implemented it in my model, and the export() call now works!
But in the next stage, converting the model to “Edge IR”, I am running into some “mutation/aliasing issues” that I couldn't crack.
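For reference, as I read the docstring of `torch._higher_order_ops.while_loop` (marked as a prototype feature), it is equivalent to the plain Python loop below, with some restrictions: `body_fn` must return tensors with the same metadata (shape, dtype) as the carried inputs, and neither `cond_fn` nor `body_fn` may mutate the carried inputs or return outputs that alias them. A minimal pure-Python sketch of those semantics (`while_loop_reference` is a hypothetical name; it ignores the tensor restrictions):

```python
def while_loop_reference(cond_fn, body_fn, carried_inputs):
    # Eager semantics: keep applying body_fn to the carried state while cond_fn holds
    val = carried_inputs
    while cond_fn(*val):
        val = body_fn(*val)
    return val


# Carried state (i, acc): sum 0 + 1 + 2 + 3 + 4
print(while_loop_reference(lambda i, acc: i < 5,
                           lambda i, acc: (i + 1, acc + i),
                           (0, 0)))  # (5, 10)
```

In my model the carried state is `(i, h, h_accum)`, and `body` returns a new tuple of the same structure on every step.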

So I created a simple PyTorch script to reproduce the issue:

```python
import torch
from torch import nn
from torch.export import export
from executorch.exir import to_edge_transform_and_lower

class ExportableLoop(nn.Module):
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.hidden_size = hidden_size
        self.B = nn.Parameter(torch.randn(hidden_size, 1))  # (H, in_channels)
        self.C = nn.Parameter(torch.randn(out_channels, hidden_size))  # (C_out, H)
        A = torch.randn(2, hidden_size)
        self.A_real = nn.Parameter(A[0])
        self.A_imag = nn.Parameter(A[1])

    def update_state(self, h, x_t):
        # h: [B, 2, H], x_t: [B, H]
        hr, hi = h[:, 0, :], h[:, 1, :]  # [B, H]
        hrn = hr * self.A_real - hi * self.A_imag + x_t  # [B, H]
        hin = hi * self.A_real + hr * self.A_imag        # [B, H]
        hn = torch.stack([hrn, hin], dim=1)              # [B, 2, H]
        return hn, hrn

    def forward(self, u):
        # u: [B, 1, T]
        x = torch.matmul(self.B, u)  # (B, H, T)
        B, H, T = x.shape

        h = torch.zeros(B, 2, H, device=x.device, dtype=x.dtype)  # [B, 2, H]
        h_accum = torch.zeros(B, H, T, device=x.device, dtype=x.dtype)  # [B, H, T]
        i = torch.tensor(0, device=x.device, dtype=torch.int64)

        def cond(i, h, h_accum):
            i_next, _, _ = i, h, h_accum
            return i_next < T

        def body(i, h, h_accum):
            x_t = x.index_select(-1, i.unsqueeze(0)).squeeze(-1)  # ✅ safe for export
            h, hr = self.update_state(h, x_t)  # h: [B, 2, H], hr: [B, H]
            h_accum = h_accum.index_copy(-1, i.unsqueeze(0), hr.unsqueeze(-1))  # [B, H, T]
            return i + 1, h, h_accum

        _, h, h_accum = torch._higher_order_ops.while_loop(cond, body, (i, h, h_accum))
        # (C_out, H) @ (B, H, T) broadcasts to (B, C_out, T); no transpose needed
        y = torch.matmul(self.C, h_accum)  # (B, C_out, T)
        return y

# Instantiate and export
model = ExportableLoop(hidden_size=128, out_channels=10)
inp = torch.randn(1, 1, 32)  # (B, in_channels=1, T=32)
exported = export(model, (inp,))
print("Exporting Done...")
edge_program = to_edge_transform_and_lower(exported)
print("Edge transform Done...")

# .buffer lives on the ExecuTorch program returned by to_executorch(), not on the edge program
executorch_program = edge_program.to_executorch()
out_path = "loop_model.pte"
with open(out_path, "wb") as file:
    file.write(executorch_program.buffer)

print(f"Successfully saved model as {out_path}")
```
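For context on the model itself: `update_state` is the real-valued form of a diagonal complex recurrence h_t = A ⊙ h_{t-1} + x_t, where A = A_real + i·A_imag and the two channels of `h` hold the real and imaginary parts. The sketch below checks that equivalence in eager mode with random tensors; `update_state_real` and `update_state_complex` are hypothetical helper names, and no export is involved:

```python
import torch


def update_state_real(h, x_t, A_real, A_imag):
    # Same arithmetic as ExportableLoop.update_state in the script
    hr, hi = h[:, 0, :], h[:, 1, :]
    hrn = hr * A_real - hi * A_imag + x_t
    hin = hi * A_real + hr * A_imag
    return torch.stack([hrn, hin], dim=1), hrn


def update_state_complex(h_c, x_t, A_c):
    # Complex form: elementwise A * h_{t-1} + x_t (real x_t is promoted to complex)
    return A_c * h_c + x_t


torch.manual_seed(0)
B, H = 2, 4
A_real, A_imag = torch.randn(H), torch.randn(H)
h = torch.randn(B, 2, H)
x_t = torch.randn(B, H)

h_next, _ = update_state_real(h, x_t, A_real, A_imag)
h_next_c = update_state_complex(torch.complex(h[:, 0], h[:, 1]), x_t,
                                torch.complex(A_real, A_imag))
print(torch.allclose(h_next[:, 0], h_next_c.real),
      torch.allclose(h_next[:, 1], h_next_c.imag))
```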

This is the error I get:

```
torch._higher_order_ops.utils.UnsupportedAliasMutationException: torch.while_loop's cond_fn might be aliasing the input!

While executing %while_loop : [num_users=3] = call_function[target=torch.ops.higher_order.while_loop](args = (%while_loop_cond_graph_0, %while_loop_body_graph_0, (%detach_, %zeros, %zeros_1), (%a_imag, %a_real, %matmul)), kwargs = {})
```
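To make the error message concrete (this illustrates what the check is looking for, it is not a diagnosis of why my `cond` is flagged after edge lowering): in PyTorch terms, an output aliases an input when it shares the input's storage, either as the same tensor or as a view of it. The bodies below are hypothetical and use an `(i, h)` carried state; `shares_storage` is an illustrative helper:

```python
import torch


def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Aliasing: both tensors are backed by the same underlying storage
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()


def body_returns_input(i, h):
    return i + 1, h             # h handed back unchanged: output aliases input


def body_returns_view(i, h):
    return i + 1, h.view_as(h)  # a view still shares the input's storage


def body_returns_clone(i, h):
    return i + 1, h.clone()     # clone() allocates new storage: no aliasing


i = torch.tensor(0)
h = torch.zeros(2, 3)
for body in (body_returns_input, body_returns_view, body_returns_clone):
    _, h_out = body(i, h)
    print(body.__name__, shares_storage(h_out, h))
```

The first two would count as aliasing under the while_loop restrictions; the clone does not.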

I posted this on the ExecuTorch forum, and they suggested I ask here in the PyTorch community.

Can someone please help me resolve this issue?

Thanks!

Hi,

Any help or insights on this would be helpful, so that I can proceed with exporting the model to ExecuTorch.

Thanks!