Hello @trusira @marksaroufim
I followed your lead and made the following changes:
torch._inductor.config.max_autotune_gemm_backends = "TRITON" # removed ATEN
torch._inductor.config.max_autotune = True
But I still get an error as follows:
File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/kernel/mm.py", line 156, in tuned_mm
return autotune_select_algorithm("mm", choices, [mat1, mat2], layout)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 991, in autotune_select_algorithm
return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 723, in __call__
raise RuntimeError(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: RuntimeError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice.
target: aten.mm.default
args[0]: TensorBox(
ReinterpretView(
StorageBox(
InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[100], stride=[1]))
),
FixedLayout('cuda', torch.float32, size=[1, 100], stride=[100, 1]),
origins={view}
)
)
args[1]: TensorBox(
ReinterpretView(
StorageBox(
InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[100, 100], stride=[100, 1]))
),
FixedLayout('cuda', torch.float32, size=[100, 100], stride=[1, 100]),
origins={permute}
)
)
#Repro
import torch
import torch._inductor.config
torch._inductor.config.trace.enabled = True
torch._inductor.config.max_autotune_gemm_backends = "TRITON"
torch._inductor.config.max_autotune = True
class ToyModel(torch.nn.Module):
    """Minimal module for the repro: one 100x100 linear layer followed by ReLU."""

    def __init__(self) -> None:
        """Create the linear layer (100 inputs, 100 outputs) and the ReLU."""
        super().__init__()
        self.l = torch.nn.Linear(100, 100)
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(l(x))`` for the input tensor ``x``."""
        return self.relu(self.l(x))
# Build the model on the first CUDA device and wrap it with torch.compile.
m = torch.compile(ToyModel().to(device="cuda:0"))

# 1-D input with 100 elements, created on the CPU and then moved to the GPU.
input_tensor = torch.randn(100).to(device="cuda:0")

# Run the compiled model once; presumably this call is what triggers the
# Inductor compilation that fails in the traceback above.
out = m(input_tensor)