Avoid loop unrolling with TorchDynamo for memory savings

I added a loop that breaks the operations into sliced steps, hoping to reduce peak memory usage, but TorchDynamo unrolls the loop, so it no longer bounds memory allocation the way I intended. Here is a somewhat simplified example of what I’m describing:

import torch

# Split the input X into row slices and apply two linear transformations to each
# slice (matmul with Wa and Wb), then accumulate A_i^T @ B_i (the sum of the
# outer products of the rows of A and B) into the result C.
# The output equals (X @ Wa).transpose(-1, -2) @ (X @ Wb),
# but the A and B intermediates should need only ~1/num_slices as much transient memory.
def sliced_matmul_chain(X, Wa, Wb, num_slices):
    Xrows, Xcols = X.shape
    # Integer rows per slice (ceil division): range() and slicing need ints,
    # and the last slice may be shorter when Xrows % num_slices != 0.
    Xstep = -(-Xrows // num_slices)
    _, Acols = Wa.shape
    _, Bcols = Wb.shape
    # Match X's dtype and device so the in-place accumulation also works for GPU inputs.
    C = torch.zeros(Acols, Bcols, dtype=X.dtype, device=X.device)
    for Xoff in range(0, Xrows, Xstep):
        Xi = X[Xoff:Xoff + Xstep, :]
        A = Xi @ Wa
        B = Xi @ Wb
        C += torch.einsum("ij,ik->jk", A, B)  # i => Xstep size, j => Acols size, k => Bcols size
    return C
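As a quick sanity check of the equivalence stated in the comments, the sketch below compares the full product against the sum of per-slice products. It runs eagerly (no TorchDynamo), uses `torch.chunk` in place of the manual offset loop, and the shapes, seed, and slice count are arbitrary choices; float64 keeps rounding error negligible.

```python
import torch

# Identity the sliced loop relies on: splitting X by rows and summing the
# per-slice products A_i^T @ B_i reproduces (X @ Wa)^T @ (X @ Wb).
torch.manual_seed(0)
X = torch.randn(1000, 64, dtype=torch.float64)
Wa = torch.randn(64, 32, dtype=torch.float64)
Wb = torch.randn(64, 48, dtype=torch.float64)

# Unsliced reference: materializes the full (1000, 32) and (1000, 48) intermediates.
full = (X @ Wa).transpose(-1, -2) @ (X @ Wb)

# Sliced version: 7 row chunks, only one chunk's A and B exist at a time.
sliced = sum(
    (Xi @ Wa).transpose(-1, -2) @ (Xi @ Wb)
    for Xi in torch.chunk(X, 7, dim=0)
)

print(torch.allclose(full, sliced))  # True
```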

In practice, this loop ends up unrolled, and the compiled code allocates as much memory as the full X @ Wa and X @ Wb intermediates would. Is there a way to stop TorchDynamo from doing this?
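For a rough sense of what is at stake, here is a back-of-the-envelope sketch counting only the elements held by the A and B intermediates. The helper `transient_elements` and the example shapes are hypothetical, and the count ignores C, the inputs, and any workspace the matmul/einsum kernels may allocate.

```python
def transient_elements(x_rows, a_cols, b_cols, num_slices):
    """Hypothetical helper: elements held by the A and B intermediates at once."""
    step = -(-x_rows // num_slices)    # ceil division: rows per slice
    full = x_rows * (a_cols + b_cols)  # X @ Wa and X @ Wb materialized in full
    sliced = step * (a_cols + b_cols)  # only one slice of A and B alive at a time
    return full, sliced

full, sliced = transient_elements(x_rows=1_000_000, a_cols=256, b_cols=256, num_slices=8)
# Assuming float32 (4 bytes per element).
print(f"{full * 4 / 1e9:.3f} GB vs {sliced * 4 / 1e9:.3f} GB")  # 2.048 GB vs 0.256 GB
```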