Torch.compile: AttributeError: module 'z3' has no attribute 'ExprRef'

Hi,

For a regression problem with PyTorch 2.1.2+cpu, my code works fine as long as I do not use torch.compile. With torch.compile, I get the following error message:

    outputs = model(x_train)
              ^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 429, in _compile
    from torch.fx.experimental.validator import (
  File "/usr/lib64/python3.11/site-packages/torch/fx/experimental/validator.py", line 58, in <module>
    def z3str(e: z3.ExprRef) -> str:
                 ^^^^^^^^^^
AttributeError: module 'z3' has no attribute 'ExprRef'
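
Judging from the traceback, torch._dynamo imports torch.fx.experimental.validator, which in turn uses the z3 Python module (normally provided by the z3-solver package, whose API does define ExprRef). A minimal diagnostic sketch, assuming some other package might be shadowing that module:

import z3

print(z3.__file__)             # path shows which installed package provides the module
print(hasattr(z3, "ExprRef"))  # True for the z3-solver API that torch expects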

Here is the code to reproduce:

import torch
import torch.nn as nn
import numpy as np
from time import time

input_size = 1
output_size = 1
hidden_size = 500
num_data = 100000

# seeds
seed = 1234
torch.manual_seed(seed)
np.random.seed(seed)

# hyper-parameters
num_epochs = 50
learning_rate = 0.01

# toy dataset
x_train = np.random.rand(num_data,input_size)
y_train = np.cos(2*np.pi*x_train) + 0.1*np.random.randn(num_data,input_size)

# regression model
model = nn.Sequential(nn.Linear(input_size, hidden_size),
                      nn.GELU(),
                      nn.Linear(hidden_size, hidden_size),
                      nn.GELU(),
                      nn.Linear(hidden_size, output_size))
# optionally compile the model
use_compile = True
if use_compile:
    model = torch.compile(model)

# loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.NAdam(model.parameters(), lr=learning_rate)  

# train the model
x_train = torch.from_numpy(x_train.astype(np.float32))
y_train = torch.from_numpy(y_train.astype(np.float32))

for epoch in range(num_epochs):
    start_time = time()
    
    # forward pass
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    
    # backward and optimize
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    print(f'Epoch {epoch}: Loss: {loss.item():.2e} in {time()-start_time:.1f}s.')

I thought it might be related to my environment; I had installed PyTorch with pip, so I reinstalled it with conda instead, but the problem stayed the same.
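
One way to narrow this down would be to ask which installed distribution actually provides the top-level z3 module. A sketch using only the standard library (requires Python 3.10+):

from importlib.metadata import packages_distributions

# Maps top-level module names to the distributions that install them.
# Anything other than ['z3-solver'] here would point to a conflicting package.
print(packages_distributions().get("z3"))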

Thanks for any help.

I cannot reproduce the issue using the current nightly build torch==2.3.0.dev20240108+cu121 and the current stable 2.1.2+cu121 release:

pip install torch torchvision 
Collecting torch
  Downloading torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl.metadata (25 kB)
...
Installing collected packages: mpmath, urllib3, typing-extensions, sympy, pillow, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, networkx, MarkupSafe, idna, fsspec, filelock, charset-normalizer, certifi, triton, requests, nvidia-cusparse-cu12, nvidia-cudnn-cu12, jinja2, nvidia-cusolver-cu12, torch, torchvision
Successfully installed MarkupSafe-2.1.3 certifi-2023.11.17 charset-normalizer-3.3.2 filelock-3.13.1 fsspec-2023.12.2 idna-3.6 jinja2-3.1.2 mpmath-1.3.0 networkx-3.2.1 numpy-1.26.3 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.18.1 nvidia-nvjitlink-cu12-12.3.101 nvidia-nvtx-cu12-12.1.105 pillow-10.2.0 requests-2.31.0 sympy-1.12 torch-2.1.2 torchvision-0.16.2 triton-2.1.0 typing-extensions-4.9.0 urllib3-2.1.0

Output:

Epoch 0: Loss: 5.13e-01 in 7.9s.
Epoch 1: Loss: 1.06e+01 in 0.5s.
Epoch 2: Loss: 2.80e+01 in 0.5s.
Epoch 3: Loss: 4.88e+00 in 0.5s.
Epoch 4: Loss: 1.28e+00 in 0.5s.
Epoch 5: Loss: 4.75e-01 in 0.5s.
Epoch 6: Loss: 4.69e-01 in 0.6s.
Epoch 7: Loss: 4.62e-01 in 0.5s.
Epoch 8: Loss: 4.53e-01 in 0.5s.
Epoch 9: Loss: 4.43e-01 in 0.5s.
Epoch 10: Loss: 4.31e-01 in 0.5s.
Epoch 11: Loss: 4.17e-01 in 0.5s.
Epoch 12: Loss: 3.99e-01 in 0.5s.
Epoch 13: Loss: 3.76e-01 in 0.5s.
Epoch 14: Loss: 3.47e-01 in 0.5s.
Epoch 15: Loss: 3.13e-01 in 0.5s.
Epoch 16: Loss: 2.84e-01 in 0.5s.
Epoch 17: Loss: 3.12e-01 in 0.5s.
Epoch 18: Loss: 4.34e-01 in 0.5s.
Epoch 19: Loss: 5.88e-01 in 0.5s.
Epoch 20: Loss: 3.54e-01 in 0.5s.
Epoch 21: Loss: 2.17e-01 in 0.5s.
Epoch 22: Loss: 1.81e-01 in 0.5s.
Epoch 23: Loss: 1.54e-01 in 0.5s.
Epoch 24: Loss: 1.32e-01 in 0.5s.
Epoch 25: Loss: 1.12e-01 in 0.5s.
Epoch 26: Loss: 9.52e-02 in 0.5s.
Epoch 27: Loss: 8.51e-02 in 0.5s.
Epoch 28: Loss: 1.00e-01 in 0.5s.
Epoch 29: Loss: 2.27e-01 in 0.5s.
Epoch 30: Loss: 4.47e-01 in 0.5s.
Epoch 31: Loss: 7.02e-01 in 0.5s.
Epoch 32: Loss: 2.08e-01 in 0.5s.
Epoch 33: Loss: 8.48e-02 in 0.5s.
Epoch 34: Loss: 6.81e-02 in 0.5s.
Epoch 35: Loss: 5.76e-02 in 0.5s.
Epoch 36: Loss: 4.92e-02 in 0.5s.
Epoch 37: Loss: 4.22e-02 in 0.5s.
Epoch 38: Loss: 3.64e-02 in 0.5s.
Epoch 39: Loss: 3.17e-02 in 0.5s.
Epoch 40: Loss: 2.79e-02 in 0.5s.
Epoch 41: Loss: 2.50e-02 in 0.5s.
Epoch 42: Loss: 2.27e-02 in 0.5s.
Epoch 43: Loss: 2.10e-02 in 0.5s.
Epoch 44: Loss: 1.98e-02 in 0.5s.
Epoch 45: Loss: 1.89e-02 in 0.5s.
Epoch 46: Loss: 1.82e-02 in 0.5s.
Epoch 47: Loss: 1.78e-02 in 0.5s.
Epoch 48: Loss: 1.75e-02 in 0.5s.
Epoch 49: Loss: 1.73e-02 in 0.5s.

Thanks for the information! Would you be able to test with a CPU-only version as well? If not, it's no big deal. Your answer suggests that my environment is the problem.

2.1.2+cpu still works for me:

pip install torch --index-url https://download.pytorch.org/whl/cpu
Looking in indexes: https://download.pytorch.org/whl/cpu
Collecting torch
  Downloading https://download.pytorch.org/whl/cpu/torch-2.1.2%2Bcpu-cp310-cp310-linux_x86_64.whl (184.9 MB)
...
Installing collected packages: mpmath, typing-extensions, sympy, networkx, MarkupSafe, fsspec, filelock, jinja2, torch
Successfully installed MarkupSafe-2.1.3 filelock-3.9.0 fsspec-2023.4.0 jinja2-3.1.2 mpmath-1.3.0 networkx-3.0 sympy-1.12 torch-2.1.2+cpu typing-extensions-4.4.0
...
pip install numpy
Collecting numpy
  Downloading numpy-1.26.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 61.2/61.2 kB 615.3 kB/s eta 0:00:00
Downloading numpy-1.26.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.2/18.2 MB 19.7 MB/s eta 0:00:00
Installing collected packages: numpy
Successfully installed numpy-1.26.3

Output:

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
Epoch 0: Loss: 5.13e-01 in 8.0s.
Epoch 1: Loss: 1.06e+01 in 0.5s.
Epoch 2: Loss: 2.80e+01 in 0.5s.
Epoch 3: Loss: 4.88e+00 in 0.5s.
Epoch 4: Loss: 1.28e+00 in 0.5s.
Epoch 5: Loss: 4.75e-01 in 0.5s.
Epoch 6: Loss: 4.69e-01 in 0.5s.
Epoch 7: Loss: 4.62e-01 in 0.5s.
Epoch 8: Loss: 4.53e-01 in 0.5s.
Epoch 9: Loss: 4.43e-01 in 0.5s.
Epoch 10: Loss: 4.31e-01 in 0.5s.
Epoch 11: Loss: 4.17e-01 in 0.5s.
Epoch 12: Loss: 3.99e-01 in 0.5s.
Epoch 13: Loss: 3.76e-01 in 0.5s.
Epoch 14: Loss: 3.47e-01 in 0.5s.
Epoch 15: Loss: 3.13e-01 in 0.5s.
Epoch 16: Loss: 2.84e-01 in 0.5s.
Epoch 17: Loss: 3.12e-01 in 0.5s.
Epoch 18: Loss: 4.34e-01 in 0.5s.
Epoch 19: Loss: 5.88e-01 in 0.5s.
Epoch 20: Loss: 3.54e-01 in 0.5s.
Epoch 21: Loss: 2.17e-01 in 0.5s.
Epoch 22: Loss: 1.81e-01 in 0.5s.
Epoch 23: Loss: 1.54e-01 in 0.5s.
Epoch 24: Loss: 1.32e-01 in 0.6s.
Epoch 25: Loss: 1.12e-01 in 0.5s.
Epoch 26: Loss: 9.52e-02 in 0.6s.
Epoch 27: Loss: 8.51e-02 in 0.5s.
Epoch 28: Loss: 1.00e-01 in 0.6s.
Epoch 29: Loss: 2.27e-01 in 0.6s.
Epoch 30: Loss: 4.47e-01 in 0.6s.
Epoch 31: Loss: 7.02e-01 in 0.5s.
Epoch 32: Loss: 2.08e-01 in 0.5s.
Epoch 33: Loss: 8.48e-02 in 0.6s.
Epoch 34: Loss: 6.81e-02 in 0.5s.
Epoch 35: Loss: 5.76e-02 in 0.5s.
Epoch 36: Loss: 4.92e-02 in 0.5s.
Epoch 37: Loss: 4.22e-02 in 0.5s.
Epoch 38: Loss: 3.64e-02 in 0.5s.
Epoch 39: Loss: 3.17e-02 in 0.5s.
Epoch 40: Loss: 2.79e-02 in 0.5s.
Epoch 41: Loss: 2.50e-02 in 0.5s.
Epoch 42: Loss: 2.27e-02 in 0.5s.
Epoch 43: Loss: 2.10e-02 in 0.5s.
Epoch 44: Loss: 1.98e-02 in 0.5s.
Epoch 45: Loss: 1.89e-02 in 0.5s.
Epoch 46: Loss: 1.82e-02 in 0.5s.
Epoch 47: Loss: 1.78e-02 in 0.5s.
Epoch 48: Loss: 1.75e-02 in 0.5s.
Epoch 49: Loss: 1.73e-02 in 0.5s.

Big thanks for the test. So I have a sort of “broken” environment… This issue is closed.

Let’s try to fix your environment. Could you create a new and empty virtual environment (e.g. via conda) and reinstall the latest PyTorch binary there?
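
For example (a minimal sketch; the environment name torch-clean is arbitrary, and the index URL matches the CPU wheels used above):

conda create -n torch-clean python=3.11
conda activate torch-clean
pip install torch --index-url https://download.pytorch.org/whl/cpu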

That’s what I’ve just done, in another directory, and it works fine.
