I’m compiling a simple for-loop and getting times roughly 25x slower than the original version.
Before I start troubleshooting, I’m wondering if this kind of workload is even supposed to be fast under torch.compile.
The code below implements a training loop with batch size 1, where the data is already loaded in memory, so it walks through the array row by row. The eager (pure Python) version takes 784 ms on an A100 with today’s nightly PyTorch build; the compiled version takes about 21 seconds.
import time

import torch
import torch._dynamo as dynamo


class timeit:
    """Context manager that prints elapsed wall-clock time in ms, synchronizing CUDA first."""

    def __init__(self, tag=""):
        self.tag = tag

    def __enter__(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.end = time.perf_counter()
        interval_ms = 1000 * (self.end - self.start)
        print(f"{interval_ms:8.2f} {self.tag}")


# fix from https://discuss.pytorch.org/t/torch-dynamo-hit-config-cache-size-limit-64/183886
torch._dynamo.config.cache_size_limit = 1000000000

duration = 8   # SGD steps per call to multiStepUpdate
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n = 200 * 784  # features
m = 10000      # datapoints
c = 200 * 10   # classes
bs = 1         # batch size
lr = 1         # learning rate

X0 = torch.randn(m, n).to(device)
Y0 = torch.randn(m, c).to(device)
W0 = torch.randn(n, c).to(device)


def multiStepUpdate(W0, idx0):
    """Run `duration` SGD steps on the least-squares loss, starting at row idx0."""
    for j in range(idx0, idx0 + duration):
        idx = j % (m - bs + 1)                  # wrap around the end of the dataset
        a = X0[idx:idx + bs, :]                 # current batch of bs rows
        y = a @ W0
        r = y - Y0[idx:idx + bs]                # residual
        loss = 0.5 * (r ** 2).sum() / (bs * c)
        g = a.T @ r / (bs * c)                  # gradient of the loss w.r.t. W0
        normalizer = 1
        W0 = W0 - lr * g * normalizer           # SGD step
    return W0
with timeit('without compile'):
    idx = 0
    for i in range(10):
        W = multiStepUpdate(W0, idx)
        idx += duration

multiStepUpdate = torch.compile(multiStepUpdate)

with timeit('with compile'):
    idx = 0
    for i in range(10):
        W = multiStepUpdate(W0, idx)
        idx += duration
This prints:
795.69 without compile
21064.64 with compile
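In case it matters: I only raised cache_size_limit because the default limit of 64 was being hit, so I suspect most of the 21 seconds is recompilation (possibly one compile per new idx0 value) rather than slow generated code. Below is roughly how I plan to rerun the compiled half of the script to check that: a warm-up call so the timed region excludes the initial compile, plus recompile logging. I’m assuming torch._logging.set_logs(recompiles=True) (or the TORCH_LOGS=recompiles env var) is the right knob for this, so treat it as a sketch rather than a verified recipe.

multiStepUpdate = torch.compile(multiStepUpdate)

# Warm-up call so the timed region below measures steady-state execution
# rather than the first compilation.
_ = multiStepUpdate(W0, 0)

# Print a message (with the failed guard) every time Dynamo recompiles;
# setting TORCH_LOGS=recompiles in the environment should do the same.
torch._logging.set_logs(recompiles=True)

with timeit('with compile, warmed up'):
    idx = 0
    for i in range(10):
        W = multiStepUpdate(W0, idx)
        idx += duration

If the log does show a recompile for every new idx0, I’d also try torch.compile(multiStepUpdate, dynamic=True) so the integer argument isn’t specialized on, but I don’t know whether that’s the recommended fix here.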