Hi,
I am trying to run the TRM model on a machine with ROCm. The following code is auto-generated by `torch.compile`:
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
# Module-level call emitted by the generator before the kernel definition.
# NOTE(review): presumably selects the GPU Triton driver; confirm against
# torch._inductor.runtime.triton_helpers.
triton_helpers.set_driver_to_gpu()
# Generated pointwise kernel `triton_poi_fused_where_0`: an elementwise
# `where` that writes out_ptr0[i] = in_ptr1[i] if in_ptr0[i // 900] else
# in_ptr2[i] for i in [0, 86400).
#
# NOTE: this listing is hard-wrapped at a fixed width, so the `triton_meta`
# and `inductor_meta` dict literals below are split in the middle of string
# literals and identifiers (e.g. 'i3' / '2'), and the original indentation of
# the function body is missing.
#
# NOTE: the 'configs' entry of `triton_meta` holds a default object repr
# (`<triton.backends.compiler.AttrsDescriptor object at 0x...>`) instead of a
# Python expression; that text is not valid Python source, which is the
# SyntaxError reported for this file.
@triton_heuristics.pointwise(
size_hints={'x': 131072},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i1', 'in_ptr1': '*i32', 'in_ptr2': '*i32', 'out_ptr0': '*i32', 'xnumel': 'i3
2'}, 'device': DeviceProperties(type='hip', index=5, multi_processor_count=110, cc='gfx90a', major=9, regs_per_multiproc
essor=65536, max_threads_per_multi_processor=2048, warp_size=64), 'constants': {}, 'configs': [<triton.backends.compiler
.AttrsDescriptor object at 0x7ffec0230040>]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_where_0', 'mutated_a
rg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': '2708FF83C40
B95EEA10411DB3D79BBB9FAE0D4AD0AA5BDC1CF2FAD7683AEFA13', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_
indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable
_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_r
block': 256, 'spill_threshold': 16, 'store_cubin': False, 'is_hip': True, 'tiling_scores': {'x': 1382400}},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_where_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
# The xnumel argument is overwritten with a hard-coded element count.
xnumel = 86400
# Each program instance handles XBLOCK consecutive indices starting at xoffset.
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
# True only where xindex < xnumel; passed to every load and to the store.
xmask = xindex < xnumel
# x1 indexes the condition input: one in_ptr0 element per run of 900
# consecutive x indices. x2 is the flat index used for in_ptr1/in_ptr2/out_ptr0.
x1 = xindex // 900
x2 = xindex
# tmp0 is the condition (in_ptr0 is declared '*i1' in the signature above).
tmp0 = tl.load(in_ptr0 + (x1), xmask, eviction_policy='evict_last')
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tmp2 = tl.load(in_ptr2 + (x2), xmask)
# Select tmp1 where tmp0 is true, otherwise tmp2.
tmp3 = tl.where(tmp0, tmp1, tmp2)
tl.store(out_ptr0 + (x2), tmp3, xmask)```
As you can see, in line 14, the `'configs'` list has the value `[<triton.backends.compiler.AttrsDescriptor object at 0x7ffec0230040>]`. This causes a `SyntaxError` (invalid syntax). Is there something I can do to resolve this? I am using PyTorch 2.8.0+rocm6.4.