Hello!
I'm trying to use the torch.compile()
function, but when I test it with my train_epoch and eval_epoch functions, the following error appears:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:324, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
323 try:
--> 324 out_code = transform_code_object(code, transform)
325 orig_code_map[out_code] = code
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:445, in transform_code_object(code, transformations, safe)
443 propagate_line_nums(instructions)
--> 445 transformations(instructions, code_options)
446 return clean_and_assemble_instructions(instructions, keys, code_options)[1]
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:299, in _compile.<locals>.transform(instructions, code_options)
298 nonlocal output
--> 299 tracer = InstructionTranslator(
300 instructions,
301 code,
302 locals,
303 globals,
304 builtins,
305 code_options,
306 compiler_fn,
307 one_graph,
308 export,
309 mutated_closure_cell_contents,
310 )
311 tracer.run()
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1670, in InstructionTranslator.__init__(self, instructions, f_code, f_locals, f_globals, f_builtins, code_options, compiler_fn, one_graph, export, mutated_closure_cell_contents)
1668 vars.extend(x for x in self.cell_and_freevars() if x not in vars)
-> 1670 self.symbolic_locals = collections.OrderedDict(
1671 (
1672 k,
1673 VariableBuilder(
1674 self,
1675 LocalInputSource(k, code_options["co_varnames"].index(k))
1676 if k in code_options["co_varnames"]
1677 else LocalSource((k)),
1678 )(f_locals[k]),
1679 )
1680 for k in vars
1681 if k in f_locals
1682 )
1684 # symbolic_locals contains the mapping from original f_locals to the
1685 # Variable objects. During the Variable building phase, each object also
1686 # has its associated guards. At the end, we will accumulate these
(...)
1699 # next invocation when args is not a list, and args[0] is a runtime
1700 # error. Therefore, we recursively add guards for list/dict variable here.
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1673, in <genexpr>(.0)
1668 vars.extend(x for x in self.cell_and_freevars() if x not in vars)
1670 self.symbolic_locals = collections.OrderedDict(
1671 (
1672 k,
-> 1673 VariableBuilder(
1674 self,
1675 LocalInputSource(k, code_options["co_varnames"].index(k))
1676 if k in code_options["co_varnames"]
1677 else LocalSource((k)),
1678 )(f_locals[k]),
1679 )
1680 for k in vars
1681 if k in f_locals
1682 )
1684 # symbolic_locals contains the mapping from original f_locals to the
1685 # Variable objects. During the Variable building phase, each object also
1686 # has its associated guards. At the end, we will accumulate these
(...)
1699 # next invocation when args is not a list, and args[0] is a runtime
1700 # error. Therefore, we recursively add guards for list/dict variable here.
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:172, in VariableBuilder.__call__(self, value)
171 return self.tx.output.side_effects[value]
--> 172 return self._wrap(value).clone(**self.options())
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:446, in VariableBuilder._wrap(self, value)
437 return NumpyVariable(
438 value,
439 source=self.source,
(...)
444 ),
445 )
--> 446 elif value in tensor_dunder_fns:
447 return TorchVariable(
448 value,
449 source=self.source,
450 guards=make_guards(GuardBuilder.FUNCTION_MATCH),
451 )
File /opt/conda/lib/python3.10/site-packages/tqdm/utils.py:75, in Comparable.__eq__(self, other)
74 def __eq__(self, other):
---> 75 return self._comparable == other._comparable
AttributeError: 'function' object has no attribute '_comparable'
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
The above exception was the direct cause of the following exception:
InternalTorchDynamoError Traceback (most recent call last)
Cell In[47], line 6
3 best_valid_loss = 1e12
5 for epoch in range(CONFIG['epochs']):
----> 6 train_loss = train_opt(model, optimizer, criterion, scheduler, train_dataloader)
7 eval_loss = eval_opt(model, criterion, valid_dataloader)
9 if valid_loss < best_valid_loss:
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:209, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
207 dynamic_ctx.__enter__()
208 try:
--> 209 return fn(*args, **kwargs)
210 finally:
211 set_eval_frame(prior)
Cell In[34], line 4, in train_epoch(model, optimizer, criterion, scheduler, dataloader)
3 def train_epoch(model, optimizer, criterion, scheduler, dataloader):
----> 4 model.train()
5 running_loss = 0.0
7 for src, trg in tqdm(dataloader):
Cell In[34], line 7, in <graph break in train_epoch>(___stack0, model, optimizer, criterion, scheduler, dataloader)
4 model.train()
5 running_loss = 0.0
----> 7 for src, trg in tqdm(dataloader):
8 src = src.to(device)
9 trg = trg.to(device) # (batch_size, seq_len)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:337, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_size)
334 return hijacked_callback(frame, cache_size, hooks)
336 with compile_lock:
--> 337 return callback(frame, cache_size, hooks)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:404, in convert_frame.<locals>._convert_frame(frame, cache_size, hooks)
402 counters["frames"]["total"] += 1
403 try:
--> 404 result = inner_convert(frame, cache_size, hooks)
405 counters["frames"]["ok"] += 1
406 return result
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:104, in wrap_convert_context.<locals>._fn(*args, **kwargs)
102 torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
103 try:
--> 104 return fn(*args, **kwargs)
105 finally:
106 torch._C._set_grad_enabled(prior_grad_mode)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:262, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_size, hooks)
259 global initial_grad_state
260 initial_grad_state = torch.is_grad_enabled()
--> 262 return _compile(
263 frame.f_code,
264 frame.f_globals,
265 frame.f_locals,
266 frame.f_builtins,
267 compiler_fn,
268 one_graph,
269 export,
270 hooks,
271 frame,
272 )
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py:163, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
161 compilation_metrics[key] = []
162 t0 = time.time()
--> 163 r = func(*args, **kwargs)
164 time_spent = time.time() - t0
165 # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:394, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
392 except Exception as e:
393 exception_handler(e, code, frame)
--> 394 raise InternalTorchDynamoError() from e
InternalTorchDynamoError:
I use the following code:
train_opt = torch.compile(train_epoch, mode="reduce-overhead")
eval_opt = torch.compile(eval_epoch, mode="reduce-overhead")

# num_epochs = 3
best_valid_loss = 1e12

for epoch in range(CONFIG['epochs']):
    train_loss = train_opt(model, optimizer, criterion, scheduler, train_dataloader)
    valid_loss = eval_opt(model, criterion, valid_dataloader)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), os.path.join(path, model_checkpoint))
        print(f"Model is saved to {os.path.join(path, model_checkpoint)}")

    print(f"Epoch №{epoch + 1}:")
    print(f"Training Loss: {train_loss}")
    print(f"Validation Loss: {valid_loss}")
    print()
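I know the error message itself says I can suppress this exception and fall back to eager, roughly as sketched below, but as I understand it that only skips compilation for the failing frames rather than fixing the underlying problem:

import torch._dynamo

# Fall back to eager execution whenever Dynamo hits an internal error,
# as suggested in the error message above (this hides the failure, it doesn't fix it).
torch._dynamo.config.suppress_errors = True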
For comparison, here is the original loop that calls train_epoch and eval_epoch directly, without torch.compile:
# add Train Loop
# num_epochs = 3
best_valid_loss = 1e12

for epoch in range(CONFIG['epochs']):
    train_loss = train_epoch(model, optimizer, criterion, scheduler, train_dataloader)
    valid_loss = eval_epoch(model, criterion, valid_dataloader)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), os.path.join(path, model_checkpoint))
        print(f"Model is saved to {os.path.join(path, model_checkpoint)}")

    print(f"Epoch №{epoch + 1}:")
    print(f"Training Loss: {train_loss}")
    print(f"Validation Loss: {valid_loss}")
    print()
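From the traceback, the crash seems to come from Dynamo's own bookkeeping rather than from my model: while Dynamo is wrapping local variables after the graph break at `for src, trg in tqdm(dataloader):`, its check `value in tensor_dunder_fns` ends up invoking tqdm's Comparable.__eq__ against a plain function, which has no _comparable attribute. Would the right fix be to keep tqdm out of the compiled code entirely, e.g. by compiling only the model instead of the whole epoch functions? Below is a rough, untested sketch of what I mean; it reuses model, device, criterion, etc. from above, and the body of the batch step is only a placeholder for my real code:

from tqdm import tqdm
import torch

# Compile only the model's forward pass; the Python-level loop and tqdm stay eager.
model_opt = torch.compile(model, mode="reduce-overhead")

def train_epoch(model, optimizer, criterion, scheduler, dataloader):
    model.train()
    running_loss = 0.0

    for src, trg in tqdm(dataloader):      # tqdm is no longer inside a compiled frame
        src = src.to(device)
        trg = trg.to(device)               # (batch_size, seq_len)

        optimizer.zero_grad()
        output = model(src, trg)           # placeholder: my real forward call
        loss = criterion(output, trg)      # placeholder: my real loss computation
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()

    return running_loss / len(dataloader)

train_loss = train_epoch(model_opt, optimizer, criterion, scheduler, train_dataloader)

Or is there a supported way to tell torch.compile to simply ignore the tqdm wrapper inside train_epoch and eval_epoch?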