Traceback (most recent call last):
  File "/home/ikenaga/student-data/liuyiyu/Spy,without_training1234P1.1+C2.1+C2.2HDRVideo-HRWeightNet/main.py", line 35, in <module>
    main(args)
  File "/home/ikenaga/student-data/liuyiyu/Spy,without_training1234P1.1+C2.1+C2.2HDRVideo-HRWeightNet/main.py", line 14, in main
    model_compile = torch.compile(model)
  File "/home/ikenaga/anaconda3/envs/lyy-2.0/lib/python3.9/site-packages/torch/__init__.py", line 1441, in compile
    return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, disable=disable)(model)
  File "/home/ikenaga/anaconda3/envs/lyy-2.0/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 182, in __call__
    assert callable(fn)
AssertionError
And here is the code:
import torch
from options import train_opts
from utils import logger, recorders, train_utils, test_utils
from datasets import custom_dataloader
from models import build_model
# Module-level setup: command-line training options are parsed once at import
# time, and the logger is built from them. Both `args` and `log` are read as
# globals by `main` and by the entry-point guard.
opts_parser = train_opts.TrainOpts()
args = opts_parser.parse()
log = logger.Logger(args)
def main(args):
    """Build the model and data loaders, then run the train/validate loop.

    Each epoch, this function:

    * updates the learning rate and records it;
    * trains for one epoch;
    * saves a checkpoint on epoch 1 and every ``args.save_intv`` epochs;
    * plots the training curves;
    * validates every ``args.val_intv`` epochs and plots the validation curves.

    Args:
        args: Parsed training options. This function itself reads
            ``start_epoch``, ``epochs``, ``save_intv`` and ``val_intv``;
            ``build_model`` and ``custom_dataloader`` may read more.
    """
    model = build_model(args, log)

    # torch.compile() only accepts a callable (an nn.Module or a function):
    # torch._dynamo's eval_frame does `assert callable(fn)` and raises a bare
    # AssertionError otherwise. The object returned by build_model is used
    # below as a training wrapper (update_learning_rate, get_learning_rate,
    # save_checkpoint). If that wrapper is not itself callable, compiling it
    # unconditionally crashes before training starts. Compile only when it is
    # callable. To get the compile speed-up for a wrapper, compile its inner
    # nn.Module network(s) inside the model class instead.
    if callable(model):
        model = torch.compile(model)
    else:
        import warnings

        warnings.warn(
            f"torch.compile skipped: {type(model).__name__} is not callable; "
            "compile its inner nn.Module network(s) instead.",
            RuntimeWarning,
            stacklevel=2,
        )

    recorder = recorders.Records(records=None)
    train_loader, val_loader = custom_dataloader(args, log)

    # NOTE(review): the loop nesting was reconstructed from an unindented
    # paste; confirm the train-curve plot is meant to run every epoch.
    for epoch in range(args.start_epoch, args.epochs + 1):
        model.update_learning_rate()
        recorder.insert_record('train', 'lr', epoch, model.get_learning_rate())
        train_utils.train(args, log, train_loader, model, epoch, recorder)

        if epoch == 1 or epoch % args.save_intv == 0:
            model.save_checkpoint(epoch, recorder.records)
        log.plot_all_curves(recorder, 'train')

        if epoch % args.val_intv == 0:
            test_utils.test(args, log, 'val', val_loader, model, epoch, recorder)
            log.plot_all_curves(recorder, 'val')
# Script entry point: seed torch's RNG for reproducibility, then train.
# The dunder names and ASCII quotes are required. `name == ‘main’` is a
# SyntaxError (typographic quotes), and `name` is undefined in any case.
if __name__ == "__main__":
    torch.manual_seed(args.seed)
    main(args)