Unable to lower to STABLEHLO hugging face ViT model

Code snippet that is causing the error:

def compile_torch_to_mhlo(model,data):
    print('Compile torch program to mhlo test\n------\n')

    import torch_mlir

    module = torch_mlir.compile(
        model,
        data,
        output_type=torch_mlir.OutputType.STABLEHLO,
        use_tracing = False
    )
    print(f"StableHLO={module}\n------\n")

if __name__ == '__main__':
...
data = torch.ones(args.batch,3,224,224)
config = AutoConfig.from_pretrained(args.model_type,num_labels=num_classes)
model = CustomViTForImageClassification(config)
model.load_state_dict(checkpoint['state_dict'], strict = False)
compile_torch_to_mhlo(model, data)

Error trace below:

/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/annotations.py:386: UserWarning: TorchScript will treat type annotations of Tensor dtype-specific subtypes as if they are normal Tensors. dtype constraints are not enforced in compilation either.
  warnings.warn(
Traceback (most recent call last):
  File "/home/nhd7682/SNL_VIT/mpc_inference.py", line 164, in <module>
    compile_torch_to_mhlo(model, data)
  File "/home/nhd7682/SNL_VIT/mpc_inference.py", line 132, in compile_torch_to_mhlo
    module = torch_mlir.compile(
             ^^^^^^^^^^^^^^^^^^^
  File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch_mlir/__init__.py", line 419, in compile
    scripted = torch.jit.script(model)
               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_script.py", line 1324, in script
    return torch.jit._recursive.create_script_module(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 559, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_script.py", line 639, in _construct
    init_fn(script_module)
  File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
    scripted = create_script_module_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_script.py", line 639, in _construct
    init_fn(script_module)
  File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
    scripted = create_script_module_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
    create_methods_and_properties_from_stubs(
  File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(
RuntimeError: 
'NoneType' object has no attribute or method 'expand'.:
  File "/scratch/nhd7682/envs_dirs/spu/lib/python3.11/site-packages/transformers/models/vit/modeling_vit.py", line 126
        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
                          ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)

I am not sure if this is a bug or I am doing something incorrectly. I am using Torch-MLIR to lower the model to StableHLO which in turn using torch.jit.script method as evident from the error trace.