Strange "IndexError" when compiling

While waiting for this to get resolved:

I stumbled on this:

pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
pip uninstall pytorch-triton 
pip install --no-deps triton==2.0.0.a2
TRITON_PTXAS_PATH=/usr/bin/ptxas python

and everything started working :slight_smile:

BUT then I ran into a strange issue here:

import torch
import torch.nn as nn

class Model(nn.Module):
    """Minimal attention module reproducing the ``torch.compile`` IndexError.

    The tensors are handed to ``scaled_dot_product_attention`` as keyword
    arguments (``query=``, ``key=``, ``value=``); that calling convention is
    what the report is about, so it is kept deliberately.

    Args:
        n_embd: Embedding width; must be divisible by ``n_head`` (6).
        bias: Whether the fused q/k/v projection has a bias term.
    """

    def __init__(self, n_embd=768, bias=False):
        # nn.Module has to be initialised before submodules can be assigned.
        super().__init__()
        self.n_head = 6
        # The embedding width follows n_embd so it always matches c_attn's input.
        self.embd = nn.Embedding(50257, n_embd)
        self.c_attn = nn.Linear(in_features=n_embd, out_features=3 * n_embd, bias=bias)
        self.dropout = 0

    def forward(self, x):
        """Embed integer ids of shape (B, T) and apply causal self-attention.

        Returns the attention output with the head axis moved in front of the
        sequence axis, i.e. shaped like the transposed query.
        """
        x = self.embd(x)
        (B, T, C) = x.size()
        # One projection produces q, k and v; split it along the channel axis.
        q, k, v = self.c_attn(x).chunk(chunks=3, dim=-1)
        # (B, T, C) -> (B, n_head, T, C // n_head)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        y = torch.nn.functional.scaled_dot_product_attention(
            query=q, key=k, value=v, attn_mask=None, dropout_p=self.dropout, is_causal=True
        )
        return y

Doing this forward pass works fine:

# Eager (uncompiled) forward pass on the GPU; this is the case reported to work.
model = Model().cuda()
batch_size = 2
# Random integer ids in [0, 50257), 1024 per row, moved to the GPU.
x = torch.randint(0, 50257, (batch_size, 1024)).cuda()
y = model(x)

However, if I compile the model:

# Wrap the model with torch.compile; the call on the next line is where the
# traceback shown further down is raised.
model = torch.compile(model)
_ = model(x)

I get an indexing error (see below for error message).

However, if I remove the `query`/`key`/`value` keyword names and pass the tensors positionally, compiling works correctly:

import torch
import torch.nn as nn

class Model(nn.Module):
    """Minimal attention module for which ``torch.compile`` works.

    Here q, k and v are passed to ``scaled_dot_product_attention``
    positionally rather than by keyword; that is the only intended
    difference from the failing variant.

    Args:
        n_embd: Embedding width; must be divisible by ``n_head`` (6).
        bias: Whether the fused q/k/v projection has a bias term.
    """

    def __init__(self, n_embd=768, bias=False):
        # nn.Module has to be initialised before submodules can be assigned.
        super().__init__()
        self.n_head = 6
        # The embedding width follows n_embd so it always matches c_attn's input.
        self.embd = nn.Embedding(50257, n_embd)
        self.c_attn = nn.Linear(in_features=n_embd, out_features=3 * n_embd, bias=bias)
        self.dropout = 0

    def forward(self, x):
        """Embed integer ids of shape (B, T) and apply causal self-attention.

        Returns the attention output with the head axis moved in front of the
        sequence axis, i.e. shaped like the transposed query.
        """
        x = self.embd(x)
        (B, T, C) = x.size()
        # One projection produces q, k and v; split it along the channel axis.
        q, k, v = self.c_attn(x).chunk(chunks=3, dim=-1)
        # (B, T, C) -> (B, n_head, T, C // n_head)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True
        )
        return y
# Driver for the positional-argument variant: build the model on the GPU,
# compile it, and call it twice on the same input.
model = Model().cuda()
batch_size = 2
# Random integer ids in [0, 50257), 1024 per row, moved to the GPU.
x = torch.randint(0, 50257, (batch_size, 1024)).cuda()
model = torch.compile(model)
# First call on the compiled model; its result is discarded.
_ = model(x)
y = model(x)

I’m running the torch nightly: ‘2.0.0.dev20230220+cu118’
and triton==2.0.0.a2

Is this just a triton error or something else? Thanks!

IndexError                                Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
    323 try:
--> 324     out_code = transform_code_object(code, transform)
    325     orig_code_map[out_code] = code

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in transform_code_object(code, transformations, safe)
    443 propagate_line_nums(instructions)
--> 445 transformations(instructions, code_options)
    446 return clean_and_assemble_instructions(instructions, keys, code_options)[1]

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in _compile.<locals>.transform(instructions, code_options)
    299 tracer = InstructionTranslator(
    300     instructions,
    301     code,
    309     mutated_closure_cell_contents,
    310 )
--> 311 tracer.run()
    312 output = tracer.output

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in
   1737 _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
-> 1738 super().run()

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in
    584 self.output.push_tx(self)
    585 while (
    586     self.instruction_pointer is not None
    587     and not self.output.should_exit
--> 588     and self.step()
    589 ):
    590     pass

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in InstructionTranslatorBase.step(self)
    551     unimplemented(f"missing: {inst.opname}")
--> 552 getattr(self, inst.opname)(inst)
    554 return inst.opname != "RETURN_VALUE"

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
    341 try:
--> 342     return inner_fn(self, inst)
    343 except Unsupported as excp:

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in InstructionTranslatorBase.CALL_FUNCTION_KW(self, inst)
   1025 assert len(kwargs) == len(argnames)
-> 1026 self.call_function(fn, args, kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
    485     raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 486 self.push(fn.call_function(self, args, kwargs))

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/, in TorchVariable.call_function(self, tx, args, kwargs)
    481 if self.value == torch._C._nn.scaled_dot_product_attention:
    482     # See:[Note] SDPA_flash's meta function returns incorrect Philox seed and offset
    483     # in pytorch/torch/
--> 484     fake_query = args[0].as_proxy().node.meta["example_value"]
    485     fake_key = args[1].as_proxy().node.meta["example_value"]

IndexError: list index out of range

from user code:
   File "/tmp/ipykernel_24294/", line 19, in forward
    y = torch.nn.functional.scaled_dot_product_attention(

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

The above exception was the direct cause of the following exception:

InternalTorchDynamoError                  Traceback (most recent call last)
Cell In[5], line 2
      1 model = torch.compile(model)
----> 2 _ = model(x)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in OptimizedModule.forward(self, *args, **kwargs)
     81 def forward(self, *args, **kwargs):
---> 82     return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    207 dynamic_ctx.__enter__()
    208 try:
--> 209     return fn(*args, **kwargs)
    210 finally:
    211     set_eval_frame(prior)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_size)
    334             return hijacked_callback(frame, cache_size, hooks)
    336 with compile_lock:
--> 337     return callback(frame, cache_size, hooks)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in convert_frame.<locals>._convert_frame(frame, cache_size, hooks)
    402 counters["frames"]["total"] += 1
    403 try:
--> 404     result = inner_convert(frame, cache_size, hooks)
    405     counters["frames"]["ok"] += 1
    406     return result

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in wrap_convert_context.<locals>._fn(*args, **kwargs)
    102 torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
    103 try:
--> 104     return fn(*args, **kwargs)
    105 finally:
    106     torch._C._set_grad_enabled(prior_grad_mode)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_size, hooks)
    259 global initial_grad_state
    260 initial_grad_state = torch.is_grad_enabled()
--> 262 return _compile(
    263     frame.f_code,
    264     frame.f_globals,
    265     frame.f_locals,
    266     frame.f_builtins,
    267     compiler_fn,
    268     one_graph,
    269     export,
    270     hooks,
    271     frame,
    272 )

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
    161     compilation_metrics[key] = []
    162 t0 = time.time()
--> 163 r = func(*args, **kwargs)
    164 time_spent = time.time() - t0
    165 # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)
    392 except Exception as e:
    393     exception_handler(e, code, frame)
--> 394     raise InternalTorchDynamoError() from e


Thanks for reporting this issue! Would you mind creating an issue on GitHub so that we could track and fix it, please?

Done, thanks @ptrblck !

1 Like