Assertion Error with torch.compile

model_compiled = torch.compile(model)

  File "/home/anaconda3/envs/python3.9/site-packages/torch/__init__.py", line 1441, in compile
    return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, disable=disable)(model)
  File "/home/anaconda3/envs/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 182, in __call__
    assert callable(fn)
AssertionError

My versions are as follows:
nvidia-driver-515, cuda=11.7, python=3.9, torch=2.0

Do you have a full repro?

Traceback (most recent call last):
  File "/home/ikenaga/student-data/liuyiyu/Spy,without_training1234P1.1+C2.1+C2.2HDRVideo-HRWeightNet/main.py", line 35, in <module>
    main(args)
  File "/home/ikenaga/student-data/liuyiyu/Spy,without_training1234P1.1+C2.1+C2.2HDRVideo-HRWeightNet/main.py", line 14, in main
    model_compile = torch.compile(model)
  File "/home/ikenaga/anaconda3/envs/lyy-2.0/lib/python3.9/site-packages/torch/__init__.py", line 1441, in compile
    return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, disable=disable)(model)
  File "/home/ikenaga/anaconda3/envs/lyy-2.0/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 182, in __call__
    assert callable(fn)
AssertionError

And here is the code:
import torch
from options import train_opts
from utils import logger, recorders, train_utils, test_utils
from datasets import custom_dataloader
from models import build_model

# Parse the training options once at import time; `args` is passed to main()
# and also read by the entry-point guard for the RNG seed.
args = train_opts.TrainOpts().parse()
# Module-level logger built from the parsed options; main() reads this global.
log = logger.Logger(args)

def main(args):
    """Build the model, compile it, and run the train/validation loop.

    For every epoch from ``args.start_epoch`` through ``args.epochs``
    (inclusive) this updates and records the learning rate, runs one training
    pass, saves a checkpoint on epoch 1 and on every ``args.save_intv``-th
    epoch, and runs validation on every ``args.val_intv``-th epoch. Training
    and validation curves are plotted through the module-level ``log``.

    Args:
        args: Parsed options; this function reads ``start_epoch``, ``epochs``,
            ``save_intv`` and ``val_intv`` and forwards ``args`` to the
            model/data/train/test helpers.
    """
    model = build_model(args, log)
    # NOTE(review): torch.compile requires a callable (it asserts
    # ``callable(fn)``). The method calls below suggest build_model returns a
    # custom wrapper object -- confirm that it is actually callable.
    model_compile = torch.compile(model)
    recorder = recorders.Records(records=None)

    train_loader, val_loader = custom_dataloader(args, log)

    for epoch in range(args.start_epoch, args.epochs + 1):
        model_compile.update_learning_rate()
        recorder.insert_record('train', 'lr', epoch, model_compile.get_learning_rate())

        train_utils.train(args, log, train_loader, model_compile, epoch, recorder)
        # Always checkpoint the first epoch, then every save_intv epochs.
        if epoch == 1 or (epoch % args.save_intv == 0):
            model_compile.save_checkpoint(epoch, recorder.records)
        log.plot_all_curves(recorder, 'train')

        if epoch % args.val_intv == 0:
            test_utils.test(args, log, 'val', val_loader, model_compile, epoch, recorder)
            log.plot_all_curves(recorder, 'val')

if __name__ == '__main__':
    # Seed PyTorch's RNG from the parsed options before starting training.
    torch.manual_seed(args.seed)
    main(args)

I’m sorry, but I still need a fully functional minimal repro, i.e. something I can copy and paste into Google Colab to reproduce the same error you’re seeing.

My guess is that the problem is save_checkpoint(): you would need to checkpoint the original model rather than the compiled one. But I’m just guessing without a repro.