While waiting for this to get resolved, I stumbled on this workaround:
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu117
pip uninstall pytorch-triton
pip install --no-deps triton==2.0.0.a2
TRITON_PTXAS_PATH=/usr/bin/ptxas python my_model_compiler.py
and everything started working.
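As a quick sanity check that the new environment actually compiles anything at all (a minimal sketch, not specific to my model):

import torch

print(torch.__version__)  # should report the 2.0.0.dev nightly

# compile a trivial function to confirm the triton/ptxas toolchain is wired up
fn = torch.compile(lambda x: torch.sin(x) + torch.cos(x))
print(fn(torch.randn(8, device="cuda")))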
BUT then I ran into a strange issue:
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, n_embd=768, bias=False):
        super().__init__()
        self.n_head = 6
        self.embd = nn.Embedding(50257, 768)
        self.c_attn = nn.Linear(in_features=n_embd, out_features=3 * n_embd, bias=bias)
        self.dropout = 0

    def forward(self, x):
        x = self.embd(x)
        (B, T, C) = x.size()
        q, k, v = self.c_attn(x).chunk(chunks=3, dim=-1)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # note: q/k/v are passed by keyword here
        y = torch.nn.functional.scaled_dot_product_attention(
            query=q, key=k, value=v, attn_mask=None, dropout_p=self.dropout, is_causal=True
        )
        return y
Running a forward pass in eager mode works fine:
model = Model().cuda()
batch_size = 2
x = torch.randint(0, 50257, (batch_size, 1024)).cuda()
y = model(x)
However, if I compile the model:
model = torch.compile(model)
_ = model(x)
I get an IndexError (full traceback below). However, if I remove the query/key/value keyword names and pass the tensors positionally, compiling works correctly:
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, n_embd=768, bias=False):
        super().__init__()
        self.n_head = 6
        self.embd = nn.Embedding(50257, 768)
        self.c_attn = nn.Linear(in_features=n_embd, out_features=3 * n_embd, bias=bias)
        self.dropout = 0

    def forward(self, x):
        x = self.embd(x)
        (B, T, C) = x.size()
        q, k, v = self.c_attn(x).chunk(chunks=3, dim=-1)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # same model as above, but q/k/v passed positionally
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True
        )
        return y
model = Model().cuda()
batch_size = 2
x = torch.randint(0, 50257, (batch_size, 1024)).cuda()
model = torch.compile(model)
_ = model(x)
y = model(x)
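To isolate the trigger outside the model, here is a stripped-down repro sketch (same environment assumed); the only difference between the two functions is keyword vs. positional arguments:

import torch
import torch.nn.functional as F

def attn_kw(q, k, v):
    # keyword arguments -> dynamo's CALL_FUNCTION_KW path -> IndexError under torch.compile
    return F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=True)

def attn_pos(q, k, v):
    # positional arguments -> compiles fine
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

q = k = v = torch.randn(2, 6, 1024, 128, device="cuda")
torch.compile(attn_pos)(q, k, v)  # works
torch.compile(attn_kw)(q, k, v)   # raises InternalTorchDynamoError for me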
I’m running the torch nightly 2.0.0.dev20230220+cu118 and triton==2.0.0.a2.
From the traceback, the failure is in dynamo's special-casing of scaled_dot_product_attention (torch/_dynamo/variables/torch.py reads args[0], which is empty here because all three tensors arrive as keywords via CALL_FUNCTION_KW), so it looks to me like a dynamo bug rather than a triton one. Is that right, or is something else going on? Thanks!
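For anyone hitting the same thing before a fix lands, the error message itself suggests a stopgap that falls back to eager instead of raising:

import torch._dynamo

# stopgap from the error message: suppress dynamo errors and fall back to eager
torch._dynamo.config.suppress_errors = True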
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:324, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
323 try:
--> 324 out_code = transform_code_object(code, transform)
325 orig_code_map[out_code] = code
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:445, in transform_code_object(code, transformations, safe)
443 propagate_line_nums(instructions)
--> 445 transformations(instructions, code_options)
446 return clean_and_assemble_instructions(instructions, keys, code_options)[1]
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:311, in _compile.<locals>.transform(instructions, code_options)
299 tracer = InstructionTranslator(
300 instructions,
301 code,
(...)
309 mutated_closure_cell_contents,
310 )
--> 311 tracer.run()
312 output = tracer.output
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1738, in InstructionTranslator.run(self)
1737 _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> 1738 super().run()
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:588, in InstructionTranslatorBase.run(self)
584 self.output.push_tx(self)
585 while (
586 self.instruction_pointer is not None
587 and not self.output.should_exit
--> 588 and self.step()
589 ):
590 pass
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:552, in InstructionTranslatorBase.step(self)
551 unimplemented(f"missing: {inst.opname}")
--> 552 getattr(self, inst.opname)(inst)
554 return inst.opname != "RETURN_VALUE"
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:342, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
341 try:
--> 342 return inner_fn(self, inst)
343 except Unsupported as excp:
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1026, in InstructionTranslatorBase.CALL_FUNCTION_KW(self, inst)
1025 assert len(kwargs) == len(argnames)
-> 1026 self.call_function(fn, args, kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:486, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
485 raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 486 self.push(fn.call_function(self, args, kwargs))
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py:484, in TorchVariable.call_function(self, tx, args, kwargs)
481 if self.value == torch._C._nn.scaled_dot_product_attention:
482 # See:[Note] SDPA_flash's meta function returns incorrect Philox seed and offset
483 # in pytorch/torch/_meta_registrations.py
--> 484 fake_query = args[0].as_proxy().node.meta["example_value"]
485 fake_key = args[1].as_proxy().node.meta["example_value"]
IndexError: list index out of range
from user code:
File "/tmp/ipykernel_24294/2653966783.py", line 19, in forward
y = torch.nn.functional.scaled_dot_product_attention(
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
The above exception was the direct cause of the following exception:
InternalTorchDynamoError Traceback (most recent call last)
Cell In[5], line 2
1 model = torch.compile(model)
----> 2 _ = model(x)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:82, in OptimizedModule.forward(self, *args, **kwargs)
81 def forward(self, *args, **kwargs):
---> 82 return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:209, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
207 dynamic_ctx.__enter__()
208 try:
--> 209 return fn(*args, **kwargs)
210 finally:
211 set_eval_frame(prior)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:337, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_size)
334 return hijacked_callback(frame, cache_size, hooks)
336 with compile_lock:
--> 337 return callback(frame, cache_size, hooks)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:404, in convert_frame.<locals>._convert_frame(frame, cache_size, hooks)
402 counters["frames"]["total"] += 1
403 try:
--> 404 result = inner_convert(frame, cache_size, hooks)
405 counters["frames"]["ok"] += 1
406 return result
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:104, in wrap_convert_context.<locals>._fn(*args, **kwargs)
102 torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
103 try:
--> 104 return fn(*args, **kwargs)
105 finally:
106 torch._C._set_grad_enabled(prior_grad_mode)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:262, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_size, hooks)
259 global initial_grad_state
260 initial_grad_state = torch.is_grad_enabled()
--> 262 return _compile(
263 frame.f_code,
264 frame.f_globals,
265 frame.f_locals,
266 frame.f_builtins,
267 compiler_fn,
268 one_graph,
269 export,
270 hooks,
271 frame,
272 )
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py:163, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
161 compilation_metrics[key] = []
162 t0 = time.time()
--> 163 r = func(*args, **kwargs)
164 time_spent = time.time() - t0
165 # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:394, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
392 except Exception as e:
393 exception_handler(e, code, frame)
--> 394 raise InternalTorchDynamoError() from e
InternalTorchDynamoError: