Compiling a for-loop makes it run 25x slower than the uncompiled version

I'm compiling a simple for-loop and getting times 25x slower than the original version.

Before I start troubleshooting, I'm wondering whether this kind of workload is even supposed to be fast under torch.compile.

The code below implements a training loop with batch size 1. The data is held in memory, so the loop reads it row by row from a tensor. The pure Python (uncompiled) version takes 784 ms on an A100 with today's nightly PyTorch; the compiled version takes 20 seconds.

```python
import time

import torch
import torch._dynamo


class timeit:
    """Context manager that prints the wall-clock time of its body in ms."""

    def __init__(self, tag=""):
        self.tag = tag

    def __enter__(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # finish pending GPU work before starting the clock
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait for them
        self.end = time.perf_counter()
        interval_ms = 1000 * (self.end - self.start)
        print(f"{interval_ms:8.2f}   {self.tag}")


# fix from
# Raise Dynamo's recompile limit so it keeps recompiling instead of falling back to eager
torch._dynamo.config.cache_size_limit = 1000000000

duration = 8  # SGD steps per call
device = 'cuda' if torch.cuda.is_available() else 'cpu'

n = 200*784   # features
m = 10000     # datapoints
c = 200*10    # classes
bs = 1        # batch size
lr = 1        # learning rate

# Allocate directly on the target device instead of creating on CPU and copying
X0 = torch.randn(m, n, device=device)
Y0 = torch.randn(m, c, device=device)
W0 = torch.randn(n, c, device=device)


def multiStepUpdate(W0, idx0):
    """Run `duration` SGD steps of linear least squares, one row at a time."""
    for j in range(idx0, idx0 + duration):
        idx = j % (m - bs + 1)  # wrap around the end of the dataset
        a = X0[idx:idx + bs, :]
        y = a @ W0
        r = y - Y0[idx:idx + bs]
        loss = 0.5 * (r**2).sum() / (bs * c)  # computed but not used below
        g = a.T @ r / (bs * c)                 # gradient of loss w.r.t. W0
        normalizer = 1
        W0 = W0 - lr * g * normalizer
    return W0


with timeit('without compile'):
    idx = 0
    for i in range(10):
        W = multiStepUpdate(W0, idx)
        idx += duration

multiStepUpdate = torch.compile(multiStepUpdate)

with timeit('with compile'):
    idx = 0
    for i in range(10):
        W = multiStepUpdate(W0, idx)
        idx += duration
```

This prints:

```text
  795.69   without compile
21064.64   with compile
```



The issue here is that the function is recompiled every time you call it, so your benchmark is timing the compilations as well. To see this, run your script with TORCH_LOGS=recompiles.
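To make "recompiled every time" concrete, here is a toy model in plain Python. It assumes, as a simplification, that the compiler specializes on the exact value of the integer argument `idx0`, so every new value is a cache miss that costs a compile. The class `ToySpecializingCompiler` and the `step` function are hypothetical and are not how Dynamo's guards are actually implemented; the point is only that the benchmark loop passes a new `idx` on every call.

```python
import time


class ToySpecializingCompiler:
    """Hypothetical toy, NOT Dynamo: one 'compiled' variant per distinct idx0."""

    def __init__(self, fn, compile_cost_s=0.01):
        self.fn = fn
        self.compile_cost_s = compile_cost_s
        self.cache = {}     # idx0 value -> "compiled" callable
        self.compiles = 0   # number of (simulated) compilations

    def __call__(self, w, idx0):
        if idx0 not in self.cache:             # cache miss -> "recompile"
            time.sleep(self.compile_cost_s)    # stand-in for compile time
            self.compiles += 1
            self.cache[idx0] = self.fn
        return self.cache[idx0](w, idx0)


def step(w, idx0):
    return w + idx0  # placeholder for the real update


compiled = ToySpecializingCompiler(step)
idx = 0
for _ in range(10):
    compiled(1.0, idx)
    idx += 8  # a new value on every call, like `idx += duration`

print(compiled.compiles)  # 10
```

Every timed call pays the compile cost, which is the pattern that TORCH_LOGS=recompiles exposes in the real script.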

Before benchmarking, you should first compile your function. To do that, you simply need to call it once with your arguments. Note that torch.compile() itself does not compile anything; compilation happens lazily on the first call.
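A minimal benchmarking helper that follows this advice might look like the sketch below. The name `benchmark` and its parameters are hypothetical. It runs warm-up calls first (with torch.compile, that is when compilation happens) and then times only the subsequent calls. With PyTorch on a GPU you would pass `sync=torch.cuda.synchronize` so that asynchronous kernels are included in the measured time.

```python
import time
from typing import Callable, Optional


def benchmark(fn: Callable, *args, warmup: int = 1, iters: int = 10,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Return mean milliseconds per call, excluding the warm-up calls."""
    for _ in range(warmup):   # first call(s) pay the one-time compile cost
        fn(*args)
    if sync is not None:
        sync()                # e.g. torch.cuda.synchronize
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if sync is not None:
        sync()                # include queued GPU work in the measurement
    return 1000 * (time.perf_counter() - start) / iters


# Hypothetical usage with the script above:
# benchmark(torch.compile(multiStepUpdate), W0, 0, sync=torch.cuda.synchronize)
```

The warm-up reuses the same arguments as the timed calls, so it only helps if those calls don't trigger new compilations. If an argument such as `idx` changes on every call, it is worth checking TORCH_LOGS=recompiles as well.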

When I exclude compilation from the benchmark, the time goes from 433 ms to 7 ms (on my RTX 2060, with the sizes reduced somewhat).

In other words, the underlying problem is pytorch/pytorch issue #111441 on GitHub, "torch.compile of simple loop takes 34 seconds". See pull request #113538, "[WIP] [DO NOT MERGE] Convert loop bodies to function calls" by Fidget-Spinner, for an effort to fix this in some cases.
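Part of why loops are expensive to compile is that Dynamo traces Python control flow by executing it, so a `for` loop over `range(...)` is unrolled into the captured graph and the amount of work handed to the compiler tends to grow with the trip count. The toy tracer below is a hypothetical illustration of that unrolling, not Dynamo's code; the op names loosely mirror the loop body in `multiStepUpdate`. Judging by its title, the pull request above aims to avoid repeating the body by turning it into a function call.

```python
class ToyTracer:
    """Hypothetical toy, NOT Dynamo: records one entry per op it observes."""

    def __init__(self):
        self.ops = []

    def op(self, name):
        self.ops.append(name)


def loop_fn(tracer, steps):
    # Running the Python loop records the body once per iteration (unrolling).
    for j in range(steps):
        for name in ("slice_x", "matmul", "slice_y_sub",
                     "loss", "grad_matmul", "update"):
            tracer.op(f"{name}_{j}")


for steps in (8, 80, 800):
    t = ToyTracer()
    loop_fn(t, steps)
    print(steps, len(t.ops))
# 8 48
# 80 480
# 800 4800
```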