torch.compile error on the backward pass

Compiling the forward pass works fine, but compilation fails on the backward pass with:

torch/_inductor/ir.py, line 2455, in _normalize_size:   
  assert len(new_size) == len(old_size)
torch._inductor.exc.LoweringException: AssertionError:
  target: aten.index_put.default
  args[0]: TensorBox(StorageBox(
    Pointwise(
      'cuda',
      torch.float16,
      def inner_fn(index):
        _, i1, i2, i3, i4, i5 = index
        tmp0 = ops.constant(0, torch.float16)
        return tmp0
      ,
      ranges=[1, 239, 522, 8, 64, 2],
      origin_node=full_default,
      origins=OrderedSet([full_default])
    )
  ))
  args[1]: [None, None, None, None, None, TensorBox(StorageBox(
    Pointwise(
      'cuda',
      torch.int64,
      def inner_fn(index):
        tmp0 = ops.constant(1, torch.int64)
        return tmp0
      ,
      ranges=[],
      origin_node=lift_fresh_copy_12,
      origins=OrderedSet([lift_fresh_copy_12])
    )
  ))]
  args[2]: TensorBox(StorageBox(
    Pointwise(
      'cuda',
      torch.float16,
      def inner_fn(index):
        _, i1, i2, i3, i4, _ = index
        tmp0 = ops.load(tangents_36, i4 + 64 * i3 + 512 * i2 + 267264 * i1)
        tmp1 = ops.load(tangents_29, i4 + 64 * i3 + 512 * i2 + 267264 * i1)
        tmp2 = ops.load(buf16, i4 + 64 * i3 + 512 * i2 + 267264 * i1)
        tmp3 = -tmp2
        tmp4 = tmp1 + tmp3
        tmp5 = ops.load(_tensor_constant14, i4 + 64 * i1)
        tmp6 = tmp4 * tmp5
        tmp7 = ops.to_dtype(tmp6, torch.float16, src_dtype=torch.float32)
        tmp8 = ops.load(tangents_39, i4 + 64 * i3 + 512 * i2 + 267264 * i1)
        tmp9 = ops.load(buf17, i4 + 64 * i3 + 512 * i2 + 267264 * i1)
        tmp10 = tmp8 + tmp9
        tmp11 = ops.load(_tensor_constant13, i4 + 64 * i1)
        tmp12 = tmp10 * tmp11
        tmp13 = ops.to_dtype(tmp12, torch.float16, src_dtype=torch.float32)
        tmp14 = tmp7 + tmp13
        tmp15 = tmp0 + tmp14
        return tmp15
      ,
      ranges=[1, 239, 522, 8, 64, 1],
      origin_node=add_28,
      origins=OrderedSet([add_28])
    )
  ))
  args[3]: True
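For what it's worth, the two ranks in the log can be reproduced in eager mode. args[1] puts a 0-d index (ranges=[]) on dim 5, while the values in args[2] are 6-D with a trailing size-1 dim. The sketch below uses small hypothetical shapes in place of (1, 239, 522, 8, 64, 2). It shows that torch.index_select keeps dim 5 as size 1 even for a 0-d index, whereas indexing dim 5 with a 0-d tensor removes that dim. A 5-D indexed region receiving 6-D values would plausibly explain the len(new_size) == len(old_size) assertion. This is an inference from the log, not a confirmed diagnosis.

```python
import torch

# Small hypothetical stand-in for the (1, 239, 522, 8, 64, 2) tensor in the post.
x = torch.randn(1, 3, 4, 8, 64, 2)

idx_0d = torch.tensor(0)    # 0-d index, as in the post
idx_1d = torch.tensor([0])  # 1-D index holding one element

# index_select sets dim 5 to index.numel() == 1 for either index,
# so the gradient flowing into the backward pass is 6-D.
sel_shape = tuple(torch.index_select(x, 5, idx_0d).shape)

# Indexing dim 5 directly (args[1] in the log is [None] * 5 + [index]):
# a 0-d index removes the dim, a 1-D index keeps it.
idx0_shape = tuple(x[:, :, :, :, :, idx_0d].shape)
idx1_shape = tuple(x[:, :, :, :, :, idx_1d].shape)

print(sel_shape)   # (1, 3, 4, 8, 64, 1)
print(idx0_shape)  # (1, 3, 4, 8, 64)
print(idx1_shape)  # (1, 3, 4, 8, 64, 1)
```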

The error seems to come from this torch.index_select call in the forward:

def forward(…):
  …
  x = x.reshape(1, 239, 522, 8, 64, 2)
  x = torch.index_select(x, dim=5, index=torch.tensor(0, dtype=torch.int64, device='cuda:0'))
  x = x.squeeze(5)
  x = torch.mul(x, y)
  ...
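torch.tensor(0, ...) above is a 0-d index. If that turns out to be the trigger, two rewrites avoid it: pass a 1-D index (torch.tensor([0], ...)), or replace index_select + squeeze with x.select(5, 0). The sketch below checks that both give the same output and gradient in eager mode and under backend="aot_eager". That backend traces the backward graph through AOTAutograd but skips Inductor, so it needs no C++/Triton toolchain. The shapes, y, and the function names are hypothetical stand-ins. Whether either rewrite clears the Inductor error in the real model is an assumption to verify by compiling with the default backend on CUDA.

```python
import torch

# Hypothetical stand-ins for the real shapes (1, 239, 522, 8, 64, 2) and for y.
SHAPE = (1, 3, 4, 8, 64, 2)

def pick_1d_index(x, y):
    # Same as the post, but the index is 1-D ([0]) instead of 0-d (0).
    idx = torch.tensor([0], dtype=torch.int64, device=x.device)
    x = torch.index_select(x, dim=5, index=idx)
    x = x.squeeze(5)
    return torch.mul(x, y)

def pick_select(x, y):
    # select() drops dim 5 in one step; no index tensor is involved.
    return torch.mul(x.select(5, 0), y)

def run(fn, x, y):
    # Forward + backward; returns the output and the gradient w.r.t. x.
    x = x.detach().clone().requires_grad_(True)
    out = fn(x, y)
    out.sum().backward()
    return out.detach(), x.grad

torch.manual_seed(0)
x = torch.randn(SHAPE)
y = torch.randn(SHAPE[:-1])

ref_out, ref_grad = run(pick_select, x, y)  # eager reference
for fn in (pick_1d_index, pick_select):
    # aot_eager builds the backward graph via AOTAutograd, then runs it eagerly.
    compiled = torch.compile(fn, backend="aot_eager")
    out, grad = run(compiled, x, y)
    assert torch.allclose(out, ref_out) and torch.allclose(grad, ref_grad)

# Only index 0 of the last dim receives gradient, and that gradient equals y.
assert torch.allclose(ref_grad[..., 0], y)
assert torch.count_nonzero(ref_grad[..., 1]) == 0
```

Here x[..., 0] would be equivalent to x.select(5, 0).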

I really don't know how to debug the backward pass when using torch.compile. I'd appreciate some help understanding the cause of this error. :face_with_medical_mask:
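One way to narrow down a backward-only failure is to compile with progressively more of the stack. backend="eager" runs only Dynamo. backend="aot_eager" adds AOTAutograd, which builds the backward graph, but no Inductor. backend="inductor" is the default. If aot_eager passes and inductor fails, the problem is likely in Inductor's lowering of the backward graph rather than in the model code. The helper first_failing_backend below is a hypothetical sketch, not a PyTorch API. Separately, setting TORCH_LOGS="aot_graphs" prints the forward and backward graphs AOTAutograd produces, and TORCHDYNAMO_REPRO_AFTER="aot" asks the minifier for a standalone repro. Both exist in recent PyTorch 2.x releases, but check the documentation for your version.

```python
import torch
import torch._dynamo

def first_failing_backend(fn, make_inputs, backends=("eager", "aot_eager", "inductor")):
    """Hypothetical helper: run forward + backward of fn under each backend.

    Returns (backend_name, exception) for the first backend that raises,
    or None if every backend succeeds.
    """
    for backend in backends:
        torch._dynamo.reset()  # drop graphs cached under the previous backend
        compiled = torch.compile(fn, backend=backend)
        inputs = make_inputs()  # fresh tensors so .grad does not carry over
        try:
            out = compiled(*inputs)
            out.sum().backward()  # the backward graph is exercised here
        except Exception as exc:
            return backend, exc
    return None
```

For the model in the post, make_inputs would build fresh CUDA tensors with requires_grad=True on each call. If the result is ("inductor", exc), the failure is isolated to Inductor.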