Precision loss with torch::jit::freeze

I’m seeing strange behavior in torch::jit::freeze involving precision loss that I don’t understand; I hope someone can clarify. The code below was run with torch 1.9.1+cu111. I have two models that are copies of each other, and both are frozen before running inference.

If I pass the same input to both, exactly the same output is produced. However, if I pass the same input twice to the first model and only once to the second, the results are equal only up to some precision loss. The behavior does not reproduce if I do not freeze the models. The printed frozen code is also the same for both models. Any ideas?

import argparse
import torch
from torch import nn
import copy


class BugModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.filter = nn.Conv2d(4, 4, kernel_size=1, padding=0, bias=False, groups=4)
        
    def forward(self, x, y):
        cov_xy = self.filter(x*y) - self.filter(x) * self.filter(y)
        return cov_xy
    

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--resolution', type=int, default=[4,4], nargs=2)
    args = parser.parse_args()

    model1 = BugModule()
    model2 = copy.deepcopy(model1)

    model1 = model1.to(device='cuda', dtype=torch.float32).eval()
    model1 = torch.jit.script(model1)
    model1 = torch.jit.freeze(model1)

    model2 = model2.to(device='cuda', dtype=torch.float32).eval()
    model2 = torch.jit.script(model2)
    model2 = torch.jit.freeze(model2)

    w, h = args.resolution
    x = torch.randn((1, 4, h, w), device='cuda', dtype=torch.float32)
    y = torch.randn((1, 4, h, w), device='cuda', dtype=torch.float32)

    with torch.no_grad():
        z1 = model1(x, y)
        # z1 = model1(x, y)  # uncomment this to see precision loss
        z2 = model2(x, y)

        if not torch.equal(z1, z2):
            print('FAILED')
        else:
            print('SUCCESS')

What you are seeing is a combination of two effects:

  • Floating-point arithmetic is only approximate. The difference you are seeing is within the range of accepted variation for single precision (roughly 1e-7 relative to the magnitude of the values), so neither of the two results is “incorrect”. Typically one would use torch.allclose or something similar, rather than torch.equal, to check whether results match.

  • In your case, the JIT fuser, which only kicks in on the second run of the module, combines the three outputs of filter differently when computing cov_xy (the multiplication and subtraction): it generates a bespoke fused kernel for them. That kernel computes a slightly different result than the separate * and - operators. Very likely, the exact order of the operations, and therefore the rounding of intermediate results, differs. Note that PyTorch makes no guarantee about which implementation runs, not even when you ask for deterministic behavior (which, by default, this computation would be from the second run on, once the JIT consistently uses the fused kernel).
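To illustrate the first point without a GPU, here is a minimal pure-Python sketch that emulates float32 by rounding through the struct module. The helpers f32, add32 and allclose_like are hypothetical names made up for this sketch; allclose_like mirrors the tolerance test documented for torch.allclose (|x - y| <= atol + rtol * |y|, with defaults rtol=1e-5 and atol=1e-8). Adding the same three float32 numbers in two different orders gives bitwise-different results that still pass the tolerance check.

```python
import struct


def f32(v: float) -> float:
    """Round a Python float (IEEE binary64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", v))[0]


def add32(a: float, b: float) -> float:
    """Emulate one float32 addition: add in binary64, then round to float32.

    For float32 inputs this matches a true float32 add, because the extra
    binary64 step cannot cause a double-rounding error for addition.
    """
    return f32(a + b)


def allclose_like(x: float, y: float, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Hypothetical scalar version of the torch.allclose test."""
    return abs(x - y) <= atol + rtol * abs(y)


# All three values are exactly representable in float32.
a, b, c = 1.0, 2.0**-24, 2.0**-24

left = add32(add32(a, b), c)    # (a + b) + c: each tiny term is rounded away
right = add32(a, add32(b, c))   # a + (b + c): the tiny terms combine first

print(left, right)                  # 1.0 1.0000001192092896
print(left == right)                # False, the torch.equal view
print(allclose_like(left, right))   # True, the torch.allclose view
```

The two results differ by exactly one float32 step near 1.0 (2**-23, about 1.2e-7), which is the "1e-7ish" scale mentioned above.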
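The second point can be sketched the same way. This assumes (an assumption, not something stated above) that the fused kernel contracts the multiply and subtract into a fused multiply-add, so the difference comes from rounding b*c to float32 before the subtraction versus rounding only once at the end. The function names and the inputs a, b and c are hypothetical stand-ins for the three filter outputs, chosen so that the two paths disagree.

```python
import struct
from fractions import Fraction

FLOAT32_EPS = 2.0**-23  # spacing of float32 values just above 1.0


def f32(v: float) -> float:
    """Round a Python float (IEEE binary64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", v))[0]


def sub_mul_unfused(a: float, b: float, c: float) -> float:
    """Separate ops: round b*c to float32, then round a - (b*c) again."""
    prod = f32(b * c)  # b*c is exact in binary64 for float32 inputs
    return f32(a - prod)


def sub_mul_fused(a: float, b: float, c: float) -> float:
    """Hypothetical FMA-style path: compute a - b*c exactly, round once.

    The Fraction -> float -> float32 conversion is exact for the inputs
    used below; for arbitrary inputs it could round twice.
    """
    exact = Fraction(a) - Fraction(b) * Fraction(c)
    return f32(float(exact))


# Hypothetical float32 stand-ins for filter(x*y), filter(x), filter(y).
a = 1.0 + 2.0**-11
b = c = 1.0 + 2.0**-12

unfused = sub_mul_unfused(a, b, c)  # b*c rounds (ties-to-even) to exactly a
fused = sub_mul_fused(a, b, c)      # keeps the tiny remainder -2**-24

print(unfused)                # 0.0
print(fused == -(2.0**-24))   # True
# Relative to the operands (about 1.0), the gap is half a float32 epsilon.
print(abs(unfused - fused) <= FLOAT32_EPS * max(abs(a), abs(b * c)))  # True
```

Note that with torch.allclose's default atol=1e-8 these two values would not count as close: the reference value is 0 and the gap is about 6e-8. When a result comes from subtracting nearly equal terms, as cov_xy does, an atol scaled to the magnitude of the operands is the safer choice.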

Best regards

Thomas