Compiling the forward pass worked just fine, but it failed on the backward pass:
torch/_inductor/ir.py, line 2455, in _normalize_size:
assert len(new_size) == len(old_size)
torch._inductor.exc.LoweringException: AssertionError:
target: aten.index_put.default
args[0]: TensorBox(StorageBox(
Pointwise(
'cuda',
torch.float16,
def inner_fn(index):
_, i1, i2, i3, i4, i5 = index
tmp0 = ops.constant(0, torch.float16)
return tmp0
,
ranges=[1, 239, 522, 8, 64, 2],
origin_node=full_default,
origins=OrderedSet([full_default])
)
))
args[1]: [None, None, None, None, None, TensorBox(StorageBox(
Pointwise(
'cuda',
torch.int64,
def inner_fn(index):
tmp0 = ops.constant(1, torch.int64)
return tmp0
,
ranges=[],
origin_node=lift_fresh_copy_12,
origins=OrderedSet([lift_fresh_copy_12])
)
))]
args[2]: TensorBox(StorageBox(
Pointwise(
'cuda',
torch.float16,
def inner_fn(index):
_, i1, i2, i3, i4, _ = index
tmp0 = ops.load(tangents_36, i4 + 64 * i3 + 512 * i2 + 267264 * i1)
tmp1 = ops.load(tangents_29, i4 + 64 * i3 + 512 * i2 + 267264 * i1)
tmp2 = ops.load(buf16, i4 + 64 * i3 + 512 * i2 + 267264 * i1)
tmp3 = -tmp2
tmp4 = tmp1 + tmp3
tmp5 = ops.load(_tensor_constant14, i4 + 64 * i1)
tmp6 = tmp4 * tmp5
tmp7 = ops.to_dtype(tmp6, torch.float16, src_dtype=torch.float32)
tmp8 = ops.load(tangents_39, i4 + 64 * i3 + 512 * i2 + 267264 * i1)
tmp9 = ops.load(buf17, i4 + 64 * i3 + 512 * i2 + 267264 * i1)
tmp10 = tmp8 + tmp9
tmp11 = ops.load(_tensor_constant13, i4 + 64 * i1)
tmp12 = tmp10 * tmp11
tmp13 = ops.to_dtype(tmp12, torch.float16, src_dtype=torch.float32)
tmp14 = tmp7 + tmp13
tmp15 = tmp0 + tmp14
return tmp15
,
ranges=[1, 239, 522, 8, 64, 1],
origin_node=add_28,
origins=OrderedSet([add_28])
)
))
args[3]: True
It seems like the error is caused by torch.index_select:
def forward(…):
…
x = x.reshape(1, 239, 522, 8, 64, 2)
x = torch.index_select(x, dim=5, index=torch.tensor(0, dtype=torch.int64, device='cuda:0'))
x = x.squeeze(5)
x = torch.mul(x, y)
...
I really do not know how to debug the backward pass while using torch.compile. I need some help to understand the reason for this error.