view operator is not compiled when it receives a DTensor wrapping a FakeTensor as input

When I parallelize the Conv1D layer (from the transformers library) with DTensor and run it under torch.compile using the code below, I get the error shown further down. It seems that the view operator is not compiled properly when the input tensor is a DTensor wrapping a FakeTensor. What I would like to do instead is perform the tracing with real tensors rather than fake tensors. Does anyone know how to do this, or is there some other workaround?

Update: I found out how to trace with real tensors, but I would still like to get this working with fake tensors.
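
For completeness, the only stopgap I am aware of is the fallback that the error message itself suggests: suppressing Dynamo errors so the failing frame runs eagerly. This is a minimal sketch of that suggestion; it sidesteps the problem rather than fixing it, since the Conv1D forward is then not compiled at all.

import torch._dynamo

# Fallback suggested by the error message below: suppress Dynamo errors and
# run frames that fail to compile in eager mode instead. The failing
# view/addmm is then executed by ordinary DTensor dispatch, not compiled.
torch._dynamo.config.suppress_errors = True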

from transformers import Conv1D
import torch.nn as nn
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# 1-D mesh over 2 ranks (launched with: torchrun --nnodes 1 --nproc-per-node 2 dtensor_test2.py)
mesh = DeviceMesh("cpu", list(range(2)))

# custom torch.compile backend that just prints the captured FX graph
def my_fn(gm, inputs):
    print(gm.graph)
    return gm

model = Conv1D(8, 16)

# parallelize with DTensor: shard weight along dim 1 (output features) and bias along dim 0
model.weight = nn.Parameter(distribute_tensor(model.weight, mesh, [Shard(1)]))
model.bias = nn.Parameter(distribute_tensor(model.bias, mesh, [Shard(0)]))

input = distribute_tensor(torch.randn(4, 16), mesh, [Shard(1)])
torch.compile(model, backend=my_fn)(input)
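
To isolate whether the failure is specific to compilation, a quick check (my own sketch, not part of the original repro) is to call the parallelized module without torch.compile and see whether plain eager DTensor dispatch handles the same view/addmm:

# Sanity check (hypothetical): run the same parallelized Conv1D eagerly.
# If this also fails, the problem lies in DTensor's sharding propagation for
# aten.view itself; if it succeeds, the problem is specific to Dynamo tracing
# the op with a DTensor that wraps a FakeTensor.
eager_out = model(input)
print(type(eager_out), eager_out.shape)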

Error message

(pytorch2) root@sunghwanshim-cpu-0:~/pytorch# torchrun --nnodes 1 --nproc-per-node 2 dtensor_test2.py 
[2023-10-25 12:04:23,314] torch.distributed.run: [WARNING] 
[2023-10-25 12:04:23,314] torch.distributed.run: [WARNING] *****************************************
[2023-10-25 12:04:23,314] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2023-10-25 12:04:23,314] torch.distributed.run: [WARNING] *****************************************
Traceback (most recent call last):
  File "dtensor_test2.py", line 19, in <module>
    torch.compile(model, backend=my_fn)(input)
  File "/root/pytorch/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/pytorch/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/pytorch/torch/_dynamo/eval_frame.py", line 338, in _fn
    return fn(*args, **kwargs)
  File "/root/pytorch/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/pytorch/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/pytorch/torch/_dynamo/eval_frame.py", line 500, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/root/pytorch/torch/_dynamo/convert_frame.py", line 634, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/root/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/root/pytorch/torch/_dynamo/convert_frame.py", line 382, in _convert_frame_assert
    return _compile(
  File "/root/pytorch/torch/_dynamo/convert_frame.py", line 562, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/root/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/root/pytorch/torch/_dynamo/convert_frame.py", line 484, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/root/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/root/pytorch/torch/_dynamo/convert_frame.py", line 451, in transform
    tracer.run()
  File "/root/pytorch/torch/_dynamo/symbolic_convert.py", line 2088, in run
    super().run()
  File "/root/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
    and self.step()
  File "/root/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
    getattr(self, inst.opname)(inst)
  File "/root/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/root/pytorch/torch/_dynamo/symbolic_convert.py", line 1119, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/root/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/root/pytorch/torch/_dynamo/variables/misc.py", line 594, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)
  File "/root/pytorch/torch/_dynamo/variables/tensor.py", line 652, in call_method
    return wrap_fx_proxy(
  File "/root/pytorch/torch/_dynamo/variables/builder.py", line 1207, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/root/pytorch/torch/_dynamo/variables/builder.py", line 1294, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx)
  File "/root/pytorch/torch/_dynamo/utils.py", line 1381, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/root/pytorch/torch/_dynamo/utils.py", line 1342, in get_fake_value
    return wrap_fake_exception(
  File "/root/pytorch/torch/_dynamo/utils.py", line 917, in wrap_fake_exception
    return fn()
  File "/root/pytorch/torch/_dynamo/utils.py", line 1343, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/root/pytorch/torch/_dynamo/utils.py", line 1415, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/root/pytorch/torch/_dynamo/utils.py", line 1404, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
  File "/root/pytorch/torch/_tensor.py", line 1386, in __torch_function__
    ret = func(*args, **kwargs)
  File "/root/pytorch/torch/distributed/_tensor/api.py", line 241, in __torch_dispatch__
    return op_dispatch.operator_dispatch(
  File "/root/pytorch/torch/distributed/_tensor/dispatch.py", line 109, in operator_dispatch
    out, _, _ = _operator_dispatch(op_call, args, kwargs, sharding_propagator)
  File "/root/pytorch/torch/distributed/_tensor/dispatch.py", line 183, in _operator_dispatch
    sharding_propagator.propagate(op_info)
  File "/root/pytorch/torch/distributed/_tensor/sharding_prop.py", line 57, in propagate
    output_sharding = self.propagate_op_sharding(op_overload, op_info.schema)
  File "/root/pytorch/torch/distributed/_tensor/sharding_prop.py", line 159, in propagate_op_sharding
    f"Sharding propagation failed on op {op_overload}.\n"
  File "/root/pytorch/torch/distributed/_tensor/op_schema.py", line 174, in __repr__
    f"OpSchema(func_schema={self.func_schema},"
  File "/root/miniconda3/envs/pytorch2/lib/python3.8/dataclasses.py", line 368, in wrapper
    result = user_function(self)
  File "<string>", line 3, in __repr__
  File "/root/pytorch/torch/distributed/_tensor/device_mesh.py", line 259, in __repr__
    return f"DeviceMesh:({self.mesh.tolist()})"
  File "/root/pytorch/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/root/pytorch/torch/_subclasses/fake_tensor.py", line 1290, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/root/pytorch/torch/_subclasses/fake_tensor.py", line 1421, in dispatch
    ) = self.validate_and_convert_non_fake_tensors(func, converter, args, kwargs)
  File "/root/pytorch/torch/_subclasses/fake_tensor.py", line 1642, in validate_and_convert_non_fake_tensors
    args, kwargs = tree_map_only(
  File "/root/pytorch/torch/utils/_pytree.py", line 362, in tree_map_only
    return tree_map(map_only(ty)(fn), pytree)
  File "/root/pytorch/torch/utils/_pytree.py", line 292, in tree_map
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/root/pytorch/torch/utils/_pytree.py", line 292, in <listcomp>
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/root/pytorch/torch/utils/_pytree.py", line 343, in inner
    return f(x)
  File "/root/pytorch/torch/_subclasses/fake_tensor.py", line 1632, in validate
    raise Exception(
torch._dynamo.exc.TorchRuntimeError: Failed running call_method view(*(DTensor(local_tensor=FakeTensor(..., size=(4, 8)), device_mesh=DeviceMesh:([0, 1]), placements=(Shard(dim=1),)), -1, 16), **{}):
Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.resolve_conj.default(tensor([...], size=(2,), dtype=torch.int32))

from user code:
   File "/root/miniconda3/envs/pytorch2/lib/python3.8/site-packages/transformers/pytorch_utils.py", line 108, in forward
    x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

(the other rank prints the same traceback; it is omitted here)

[2023-10-25 12:04:28,336] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 137293) of binary: /root/miniconda3/envs/pytorch2/bin/python
Traceback (most recent call last):
  File "/root/miniconda3/envs/pytorch2/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch', 'console_scripts', 'torchrun')())
  File "/root/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/root/pytorch/torch/distributed/run.py", line 806, in main
    run(args)
  File "/root/pytorch/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/root/pytorch/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/root/pytorch/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
dtensor_test2.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-10-25_12:04:28
  host      : sunghwanshim-cpu-0
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 137294)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-10-25_12:04:28
  host      : sunghwanshim-cpu-0
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 137293)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Versions

PyTorch version: 2.2.0a0+gite4f3e54
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: 11.0.1-2
CMake version: version 3.26.4
Libc version: glibc-2.31

Python version: 3.8.18 (default, Sep 11 2023, 13:40:15)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-87-generic-x86_64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 48 bits virtual
CPU(s):                             192
On-line CPU(s) list:                0-191
Thread(s) per core:                 2
Core(s) per socket:                 24
Socket(s):                          4
NUMA node(s):                       4
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              85
Model name:                         Intel(R) Xeon(R) Gold 6348H CPU @ 2.30GHz
Stepping:                           11
CPU MHz:                            1000.000
CPU max MHz:                        4200.0000
CPU min MHz:                        1000.0000
BogoMIPS:                           4600.00
Virtualization:                     VT-x
L1d cache:                          3 MiB
L1i cache:                          3 MiB
L2 cache:                           96 MiB
L3 cache:                           132 MiB
NUMA node0 CPU(s):                  0-23,96-119
NUMA node1 CPU(s):                  24-47,120-143
NUMA node2 CPU(s):                  48-71,144-167
NUMA node3 CPU(s):                  72-95,168-191
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] optree==0.9.2
[pip3] torch==2.2.0a0+gite4f3e54
[conda] mkl                       2023.1.0         h213fc3f_46343  
[conda] mkl-include               2023.1.0         h06a4308_46343  
[conda] numpy                     1.24.4                   pypi_0    pypi
[conda] optree                    0.9.2                    pypi_0    pypi
[conda] torch                     2.2.0a0+gite4f3e54           dev_0    <develop>