Hi,

I’ve been attempting to build some code that leverages PyTorch 2.0’s `torch.compile`. However, whenever I build a compiled training routine, the model weights are never actually updated. I’ve tried this both by compiling the model itself and by compiling each step of the training loop. In both cases the compiled version produces an essentially constant evaluated loss, whereas disabling compilation shows convergence.

To try and recreate this in its most basic form, I went back to the Torch 2.0 compile tutorial ( torch.compile Tutorial — PyTorch Tutorials 2.0.0+cu117 documentation ), and have built a minimum viable reproduction based upon that.

```
import torch
from torchvision.models import resnet18
import numpy as np
from torchvision.models import resnet18
def timed(fn):
    """Run ``fn()`` once and return ``(result, elapsed_seconds)``.

    The post elides this helper with ``...``; left as a stub it returns
    ``None`` and the ``lv, t = timed(...)`` unpacking in the benchmark loops
    fails. This restores the CUDA-event timer from the torch.compile tutorial
    the post is based on, and falls back to a wall-clock timer when no CUDA
    device is available.

    Args:
        fn: Zero-argument callable to execute and time.

    Returns:
        A ``(result, seconds)`` tuple, where ``result`` is whatever ``fn``
        returned and ``seconds`` is the elapsed time as a float.
    """
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = fn()
        end.record()
        # CUDA work is asynchronous; wait for it before reading the events.
        torch.cuda.synchronize()
        # elapsed_time() reports milliseconds; convert to seconds.
        return result, start.elapsed_time(end) / 1000

    # Local import keeps this block self-contained without touching the
    # file-level import list.
    import time

    begin = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - begin
def generate_data(b, device="cuda"):
    """Return a random input batch and a constant regression target.

    Args:
        b: Batch size.
        device: Device both tensors are placed on. Defaults to ``"cuda"``,
            matching the original hard-coded ``.cuda()`` calls.

    Returns:
        A tuple ``(inputs, targets)``: ``inputs`` is a float32 tensor of
        shape ``(b, 3, 128, 128)`` drawn from ``torch.randn``; ``targets`` is
        a float32 tensor of shape ``(b, 1)`` in which every entry is 0.5.
    """
    a = torch.randn(b, 3, 128, 128).to(torch.float32).to(device)
    # The target is derived from `a`, so it already has the right dtype and
    # device; the extra cast/copy in the original was redundant.
    target = 0.5 + 0 * a[:, 0, 0, 0].reshape(-1, 1)  # Same effect without the 0*
    return a, target
class MyModule(torch.nn.Module):
    """Small demo module: noisy input -> linear layer (100 -> 10) -> ReLU."""

    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        """Perturb ``x`` with random noise, then apply ``lin`` and ReLU.

        Args:
            x: Input tensor whose last dimension is 100 (the ``in_features``
                of ``self.lin``).

        Returns:
            The ReLU-activated output of the linear layer (last dimension 10).
        """
        # Fresh noise (scaled by 0.1) is sampled on every call, so the output
        # is not deterministic for a fixed input.
        y = x + 0.1 * torch.randn_like(x)
        return torch.nn.functional.relu(self.lin(y))
# Demo module instance and its torch.compile wrapper. Neither global is used
# by the rest of the visible script (the `mod` parameters of `evaluate` and
# `train` shadow the global `mod`).
mod = MyModule()
opt_mod = torch.compile(mod)
def init_model():
    """Build a fresh ResNet-18 with a single output, as float32 on CUDA.

    Returns:
        A newly constructed ``resnet18(num_classes=1)`` cast to float32 and
        moved to the CUDA device.
    """
    return resnet18(num_classes=1).to(torch.float32).cuda()
def evaluate(mod, inp):
    """Reshape ``inp`` into a 3-channel batch of 16 and run ``mod`` on it.

    Args:
        mod: Callable applied to the prepared tensor.
        inp: Tensor with exactly 16 elements (``reshape(16, 1, 1, 1)`` raises
            otherwise).

    Returns:
        Whatever ``mod`` returns for the ``(16, 3, 1, 1)`` input.
    """
    inp = inp.reshape(16, 1, 1, 1)
    # Copy the single value per sample across three channels.
    inp = inp.repeat(1, 3, 1, 1)
    # Track gradients with respect to the prepared input.
    inp.requires_grad_()
    return mod(inp)
# Model and Adam optimizer (default hyperparameters) for the eager run.
# `train` reads `opt` as a module-level global rather than taking it as an
# argument, so it must be bound before `train` is first called.
model = init_model()
opt = torch.optim.Adam(model.parameters())
def train(mod, data):
    """Run one optimization step and return log10 of the MSE loss.

    Uses the module-level optimizer ``opt``; the caller must keep ``opt``
    bound to an optimizer built over ``mod``'s parameters.

    Args:
        mod: Model to train; called as ``mod(data[0])``.
        data: Pair ``(inputs, targets)``.

    Returns:
        ``numpy.log10`` of the scalar loss value for this step.
    """
    # The positional True is zero_grad's `set_to_none` argument.
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.MSELoss()(pred, data[1])  # Have also tested with a CrossEntropy variant
    loss.backward()
    opt.step()
    # .item() copies the scalar loss to the host before taking the log.
    return np.log10(loss.item())
N_ITERS = 1000

# --- Eager baseline: trains the `model` / `opt` globals created earlier. ---
eager_times = []
eager_losses = []
for i in range(N_ITERS):
    # `generate_data` is the only data generator defined in this script; the
    # previous call to `generate_data_2` raised NameError.
    inp = generate_data(64)
    lv, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
    eager_losses.append(lv)
print("~" * 10)

# --- Compiled run: fresh model and optimizer, same step function compiled. ---
# `opt` is rebound here because `train` looks it up as a global.
model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")
compile_times = []
compile_losses = []
for i in range(N_ITERS):
    inp = generate_data(64)
    lv, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
    compile_losses.append(lv)
print("~" * 10)

# --- Timing summary: median per-step time and the resulting speedup. ---
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x", flush=True)
print("~" * 10)

# --- Convergence summary: mean log10-loss over an early and a late window. ---
comp_array, eager_array = np.asarray(compile_losses), np.asarray(eager_losses)
print(f'{np.mean(comp_array[10:200])} {np.mean(comp_array[N_ITERS - 200:N_ITERS - 10])}', flush=True)
print(f'{np.mean(eager_array[10:200])} {np.mean(eager_array[N_ITERS - 200:N_ITERS - 10])}', flush=True)
print(f'{comp_array}', flush=True)  # Nominally constant, but with some slight variation
print('#'*100, flush=True)
print(f'{eager_array}', flush=True)  # Clear demonstrated convergence
```

The compiled model produces nominally constant losses, whereas the eager variant rapidly converges to losses that are O(10^-4).

Given that this is a very lightly modified variant of the torch.compile tutorial, I was hoping that somebody could point out to me what I’m missing. Is it as simple as pass by reference not working through torch.compile?

Thanks in advance