Discrepancies Between Compiled and Non-Compiled Models with Convolutional Layers in PyTorch

Hello everyone,

I’ve been experimenting with PyTorch’s torch.compile function (with mode='max-autotune') to optimize my model. When I compare the outputs of the compiled and non-compiled versions of the network, I see discrepancies, but only when convolutional layers are involved. For networks made only of linear layers, the outputs are exactly the same.

Is this expected behavior? Does torch.compile in max-autotune mode reduce precision? I do see a substantial speedup with the compiled version (on an RTX 3080), but I worry that it comes at the cost of precision, so I hesitate to adopt it in my code. I’m using the latest stable release of torch, 2.4.1.

Here is a minimal example:

import torch
import torch.nn as nn
from torch.testing import assert_close  # assert_allclose is deprecated in favor of assert_close

class ConvLinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # Flatten to (batch, 32 * 8 * 8)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

def test_accuracy_over_runs(num_runs=100):
    discrepancies = 0
    model = ConvLinearModel()
    compiled_model = torch.compile(model, mode='max-autotune')

    # Compare eager and compiled outputs on a freshly seeded random input each run
    for i in range(num_runs):
        torch.manual_seed(42 + i)
        input_data = torch.randn(10, 3, 8, 8)
        non_compiled_output = model(input_data)
        compiled_output = compiled_model(input_data)

        try:
            assert_close(compiled_output, non_compiled_output, rtol=1e-5, atol=1e-8)
        except AssertionError:
            # Count runs where any element exceeds atol + rtol * |expected|
            discrepancies += 1

    return discrepancies

# Testing
discrepancies = test_accuracy_over_runs(100)
print(f"Number of discrepancies found: {discrepancies}")

When I run your script on a nightly build (on an H100; I don’t have an RTX card to test on), I don’t get any discrepancies.

Just a few ideas/questions:

(1) Have you tried a recent nightly?

(2) You specified atol=1e-8, which is quite small. In general, the compiler can fuse operations, and fusion can cause slight changes in numerics. What are the largest absolute and relative differences you’re seeing?

(3) Does this only repro with mode='max-autotune'? Or does it also repro with vanilla inductor (the default torch.compile mode)?
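To make point (2) concrete: floating-point addition is not associative, so a fused kernel that accumulates the same terms in a different order than the eager kernel can legitimately produce slightly different bits. A minimal sketch using plain Python floats (float64), assuming nothing beyond standard IEEE 754 rounding:

```python
# Floating-point addition is not associative: the grouping changes the rounding.
a, b, c = 0.1, 0.2, 0.3

left = (a + b) + c   # 0.1 + 0.2 rounds to 0.30000000000000004 before adding 0.3
right = a + (b + c)  # 0.2 + 0.3 rounds to exactly 0.5 before adding 0.1

print(left)               # 0.6000000000000001
print(right)              # 0.6
print(left == right)      # False
print(abs(left - right))  # a difference on the order of 1e-16
```

In float32, where machine epsilon is about 1.19e-7, the same reordering effect is proportionally larger, so a tolerance as tight as atol=1e-8 may flag differences that are only rounding noise.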
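For questions (2) and (3), here is a sketch that reports the largest absolute and relative differences against eager mode, instead of counting assert failures, for both the default inductor mode and max-autotune. The helper names max_differences and compare_modes are hypothetical, and the model is an nn.Sequential rebuilt with the same layer shapes as ConvLinearModel above so that the block runs on its own. It uses the GPU when one is available and falls back to CPU otherwise.

```python
import torch
import torch._dynamo
import torch.nn as nn


def max_differences(actual: torch.Tensor, expected: torch.Tensor, eps: float = 1e-12):
    """Return (max absolute difference, max elementwise relative difference)."""
    abs_diff = (actual - expected).abs()
    # eps guards against division by zero where the reference value is exactly 0
    rel_diff = abs_diff / (expected.abs() + eps)
    return abs_diff.max().item(), rel_diff.max().item()


def compare_modes(modes=("default", "max-autotune"), seed=0):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(seed)
    # Hypothetical rewrite: same layer shapes as ConvLinearModel, as nn.Sequential
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1), nn.ReLU(),
        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
        nn.Flatten(),
        nn.Linear(32 * 8 * 8, 512), nn.ReLU(),
        nn.Linear(512, 10),
    ).to(device).eval()
    x = torch.randn(10, 3, 8, 8, device=device)

    results = {}
    with torch.no_grad():
        reference = model(x)  # eager-mode output used as the baseline
        for mode in modes:
            torch._dynamo.reset()  # clear Dynamo caches so each mode compiles from scratch
            compiled = torch.compile(model, mode=mode)
            results[mode] = max_differences(compiled(x), reference)
            max_abs, max_rel = results[mode]
            print(f"{mode:>12}: max abs diff = {max_abs:.3e}, max rel diff = {max_rel:.3e}")
    return results
```

Calling compare_modes() prints one line per mode and returns a dict mapping each mode to its (max abs, max rel) pair. The elementwise relative difference can be large wherever a reference value is close to zero, so the absolute figure may be the more informative of the two for outputs like these.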