Training under Compile


I’ve been trying to build some code that uses the PyTorch 2.0 compile feature (torch.compile). However, whenever I build a compiled training routine, the model weights are never actually updated. I’ve tried this both by compiling the model directly and by compiling each step of the training loop. In both cases, compilation leads to a constant evaluated loss, whereas disabling compilation shows convergence.

To recreate this in its most basic form, I went back to the Torch 2.0 compile tutorial (the "torch.compile Tutorial" in the PyTorch Tutorials 2.0.0+cu117 documentation) and built a minimal reproduction based on it.

import numpy as np
import torch
from torchvision.models import resnet18

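def timed(fn):
    # Time fn() on the GPU with CUDA events, as in the torch.compile tutorial
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000  # elapsed time in seconds

def generate_data_2(b):
    a = torch.randn(b, 3, 128, 128).to(torch.float32).cuda()
    # Inputs of shape (b, 3, 128, 128) and a constant target of shape (b, 1)
    return (
        a,
        0.5 + 0*a[:, 0, 0, 0].reshape(-1, 1).to(torch.float32).cuda(),  # Same effect without the 0*
    )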

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        y = x + 0.1*torch.randn_like(x)
        return torch.nn.functional.relu(self.lin(y))
mod = MyModule()
opt_mod = torch.compile(mod)

def init_model():
    return resnet18(num_classes=1).to(torch.float32).cuda()
def evaluate(mod, inp):
    inp = inp.reshape(16, 1, 1, 1)
    inp = inp.repeat(1,3,1,1)
    return mod(inp)

model = init_model()
opt = torch.optim.Adam(model.parameters())

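def train(mod, data):
    # One full optimisation step using the global optimizer `opt`
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.MSELoss()(pred, data[1]) # Have also tested with a CrossEntropy variant
    loss.backward()
    opt.step()
    return np.log10(loss.item())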

N_ITERS = 1000

eager_times = []
eager_losses = []
for i in range(N_ITERS):
    inp = generate_data_2(64)
    lv, eager_time = timed(lambda: train(model, inp))
    eager_losses.append(lv)
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = []
compile_losses = []
for i in range(N_ITERS):
    inp = generate_data_2(64)
    lv, compile_time = timed(lambda: train_opt(model, inp))
    compile_losses.append(lv)
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x", flush=True)
print("~" * 10)

comp_array, eager_array = np.asarray(compile_losses), np.asarray(eager_losses)
print(f'{np.mean(comp_array[10:200])} {np.mean(comp_array[N_ITERS - 200:N_ITERS - 10])}', flush=True)
print(f'{np.mean(eager_array[10:200])} {np.mean(eager_array[N_ITERS - 200:N_ITERS - 10])}', flush=True)

print(f'{comp_array}', flush=True) # Nominally constant, but with some slight variation
print('#'*100, flush=True)
print(f'{eager_array}', flush=True) # Clear demonstrated convergence

The compiled model produces nominally constant losses, whereas the eager variant rapidly converges to losses of O(10^-4).
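The script above tells a flat loss curve from a converging one by comparing mean log10 losses over an early window and a late window of iterations. The same check can be written as a small helper. This is a sketch assuming losses are stored as log10 values (which is what train returns); the names window_means and is_converging, and the one-decade threshold min_drop=1.0, are hypothetical choices for illustration.

```python
from statistics import mean
from typing import Sequence


def window_means(log_losses: Sequence[float], skip: int = 10, width: int = 190) -> tuple[float, float]:
    """Mean log10 loss over an early and a late window.

    With the defaults this mirrors the slices in the script:
    [10:200] at the start and [n - 200:n - 10] at the end.
    """
    n = len(log_losses)
    early = log_losses[skip:skip + width]
    late = log_losses[n - skip - width:n - skip]
    return mean(early), mean(late)


def is_converging(log_losses: Sequence[float], min_drop: float = 1.0) -> bool:
    """True if the late-window mean is at least `min_drop` decades below the early one."""
    early, late = window_means(log_losses)
    return early - late >= min_drop


# Synthetic curves: one falling from 10^-1 to 10^-4, one stuck at 10^-1.
eager_like = [-1.0 - 3.0 * i / 999 for i in range(1000)]
flat = [-1.0] * 1000
print(is_converging(eager_like))  # True
print(is_converging(flat))  # False
```

Working in log10 space makes "converged to O(10^-4)" a difference of a few units, so a fixed threshold is easy to read.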

Given that this is a very lightly modified version of the torch.compile tutorial, I was hoping somebody could point out what I’m missing. Is it as simple as pass-by-reference not working through torch.compile?

Thanks in advance

Could you remove the mode="reduce-overhead" argument and check whether that fixes the issue?
I reported this issue a while ago in this comment, and based on the latest updates it might already be fixed in the current nightly release, so you could try that, too.

Hi ptrblck, thanks for your comment, and apologies for the delay in getting back to you. I can confirm that running it in either the default or the max-autotune mode results in the expected behaviour on 2.0.0+cu117. I appreciate the knowledge you share on these forums.

I might post this as a separate thread to keep things distinct, but I’m running into an issue with compile where the last batch from a DataLoader triggers "AttributeError: 'UnspecializedNNModuleVariable' object has no attribute 'module_key'", which kills the code without any indication of what exactly is causing it. Is there any documentation I can follow to understand what triggers this error, and why it only happens on the last batch? I had assumed it was because the last batch has a different size, but I’ve tried both padding and drop_last in the DataLoader, and neither helped.

The module_key error should come from the Dynamo stack.
Could you post the full stack trace here?
Also, are you able to reproduce the same error by using the last batch “standalone”, i.e. outside of the DataLoader loop?
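One way to run the last batch standalone is to materialise the batches, take the final one, and call the compiled model on it directly, with no DataLoader loop around the call. This is a sketch assuming toy data (10 samples, batch_size=4) and a hypothetical nn.Linear standing in for the real model. It also shows how drop_last changes the batch sizes.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 10 samples, so batch_size=4 leaves a final batch of 2.
dataset = TensorDataset(torch.randn(10, 3), torch.randn(10, 1))

loader = DataLoader(dataset, batch_size=4)
batches = list(loader)
print([xb.shape[0] for xb, _ in batches])  # [4, 4, 2]

# drop_last=True discards the incomplete final batch instead.
print([xb.shape[0] for xb, _ in DataLoader(dataset, batch_size=4, drop_last=True)])  # [4, 4]

# Take the last batch out of the loop and feed it to the compiled model on its own.
model = torch.nn.Linear(3, 1)  # hypothetical stand-in for the real model
compiled_model = torch.compile(model)
last_x, last_y = batches[-1]
out = compiled_model(last_x)
loss = torch.nn.functional.mse_loss(out, last_y)
print(out.shape)  # torch.Size([2, 1])
```

If the error appears here too, the DataLoader is ruled out and the problem lies in how the model or step function is traced.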

Apologies for the delay in getting back to you. I ended up refactoring the code a fair bit and am no longer experiencing the issue. Thanks again for your assistance, it’s always appreciated.