I'm seeing strange behavior in torch::jit::freeze, a precision loss that I don't understand, and I hope someone can clarify. The code below was run with torch 1.9.1+cu111. I have two models that are copies of each other, and both are frozen before running inference.
If I pass the same input to both, they produce exactly the same output. However, if I pass the same input twice to the first model and only once to the second, the results match only up to some precision loss. This does not reproduce if I skip freezing, and the printed frozen code is identical for both models. Any ideas?
import argparse
import copy

import torch
from torch import nn


class BugModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.filter = nn.Conv2d(4, 4, kernel_size=1, padding=0, bias=False, groups=4)

    def forward(self, x, y):
        cov_xy = self.filter(x * y) - self.filter(x) * self.filter(y)
        return cov_xy


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--resolution', type=int, default=[4, 4], nargs=2)
    args = parser.parse_args()

    model1 = BugModule()
    model2 = copy.deepcopy(model1)

    model1 = model1.to(device='cuda', dtype=torch.float32).eval()
    model1 = torch.jit.script(model1)
    model1 = torch.jit.freeze(model1)

    model2 = model2.to(device='cuda', dtype=torch.float32).eval()
    model2 = torch.jit.script(model2)
    model2 = torch.jit.freeze(model2)

    w, h = args.resolution
    x = torch.randn((1, 4, h, w), device='cuda', dtype=torch.float32)
    y = torch.randn((1, 4, h, w), device='cuda', dtype=torch.float32)

    with torch.no_grad():
        z1 = model1(x, y)
        # z1 = model1(x, y)  # uncomment this to see precision loss
        z2 = model2(x, y)

    if not torch.equal(z1, z2):
        print('FAILED')
    else:
        print('SUCCESS')
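One thing that might make the report easier to interpret: `torch.equal` is all-or-nothing, so it can't distinguish a last-ulp float discrepancy from a real divergence. A small diagnostic helper (my own addition, not part of the original repro; the name `report_diff` and the tolerance are assumptions) that prints the maximum absolute difference and checks it against a tolerance could be swapped in for the `torch.equal(z1, z2)` check:

```python
import torch


def report_diff(a: torch.Tensor, b: torch.Tensor, atol: float = 1e-6) -> bool:
    """Print the largest absolute elementwise difference between two
    tensors and return whether they agree within the given tolerance."""
    max_diff = (a - b).abs().max().item()
    print(f'max abs diff: {max_diff:.3e}')
    return torch.allclose(a, b, atol=atol)
```

In the repro above, `if not report_diff(z1, z2): print('FAILED')` would then show whether the mismatch is on the order of float32 rounding (roughly 1e-7 relative) or something larger.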