How to avoid the "traced function does not match the corresponding output of the Python function" warning?

(Nuuo ) #1

Hey guys, I am trying to load my PyTorch model in C++. I use

traced_script_module = torch.jit.trace(net, (dammy_input1, dammy_input2))

to convert my model (it has two inputs). When I run this, it raises a

TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function.

Some values are not within the tolerance of 1e-5; in fact, they are very different, e.g. (-0.242 vs. 3.92).
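For reference, the comparison behind this warning can be reproduced by hand. The snippet below is only a minimal sketch: `TwoInputNet` is a hypothetical stand-in for a two-input model (not my actual network), and the input shapes are made up. It traces the module, runs both versions on the same inputs, and reports the largest absolute difference.

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    # Hypothetical stand-in for a model that takes two inputs.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def forward(self, a, b):
        return self.conv(a) * self.conv(b)


torch.manual_seed(0)
net = TwoInputNet().eval()
x1 = torch.randn(1, 3, 8, 8)  # made-up example inputs
x2 = torch.randn(1, 3, 8, 8)

# Trace with a tuple of example inputs, one entry per forward() argument.
traced = torch.jit.trace(net, (x1, x2))

with torch.no_grad():
    eager_out = net(x1, x2)
    traced_out = traced(x1, x2)

# Largest elementwise gap between the Python and the traced outputs.
max_abs_diff = (eager_out - traced_out).abs().max().item()
matches = torch.allclose(eager_out, traced_out, atol=1e-5)
print(max_abs_diff, matches)
```

Running the same comparison on a single submodule (for example only the correlation step) is one way to narrow down which part of the model diverges.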
I rechecked my model. Its structure is as follows:

conv1
conv2
correlation
some convs and upsamples
...
return

I found that the mismatch between the traced function and the Python function is caused by the correlation module: I checked the output of conv2 just before the correlation, and there the two versions agree within the tolerance of 1e-5. So the correlation module is the problem. The correlation module looks like this:

class Corr(nn.Module):
    def __init__(self, ...):
        # some parameters init here
        ...
    def corr_func(self, fL, fR):
        return (fL*fR).sum(dim=1)
    
    def forward(self, img1, img2):
        num, bchannels, bheight, bwidth = img1.shape[0], img1.shape[1], img1.shape[2], img1.shape[3]
        top = torch.zeros(num, 27, 192, 192).type_as(img1.data)
        x_shift = - self.neighborhood_grid_radius
        
        sumelems = self.kernel_size * self.kernel_size * bchannels
        for ch in range(0,  self.top_channels):
            s2o = int(( ch % self.neighborhood_grid_width_ + x_shift ) * self.stride2) + self.pad_shift
            if s2o > 0 :
                top[:,ch,:, :-s2o] = self.corr_func(img1[:, :, :, :-s2o] , img2[:, :, :, s2o:])
            elif s2o < 0:
                top[:, ch, :, -s2o:] = self.corr_func(img1[:, :, :, -s2o:], img2[:, :, :, :s2o])
            else:
                top[:, ch, :, :]  = self.corr_func(img1, img2)
                
        top = top / sumelems
        return top

The basic idea of corr is to take a patch from img1 and the corresponding patch from img2, multiply them elementwise, and sum the result. I think the correlation is quite similar to a conv. Why does torch.jit.trace produce such different results between the Python and traced functions, and how should I modify the correlation module code to avoid the mismatch?
Any suggestions will be appreciated! Thanks a lot!
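One variant that might be worth trying is sketched below. It rests on an unconfirmed assumption: that the in-place slice writes into the preallocated `top` tensor are what the tracer records differently. The sketch builds each output channel out of place by shifting img2 with zero padding and stacking the results. `CorrNoInplace`, `corr_reference`, the `shifts` list (standing in for the per-channel `s2o` values) and the tensor sizes are all hypothetical names and values chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CorrNoInplace(nn.Module):
    # Hypothetical rewrite: no preallocated output, no slice assignment.
    def __init__(self, shifts, kernel_size=1):
        super().__init__()
        self.shifts = [int(s) for s in shifts]  # stand-in for the s2o values
        self.kernel_size = kernel_size

    def forward(self, img1, img2):
        sumelems = self.kernel_size * self.kernel_size * img1.shape[1]
        outs = []
        for s2o in self.shifts:
            if s2o > 0:
                # shifted[x] = img2[x + s2o], zero-padded on the right
                shifted = F.pad(img2[:, :, :, s2o:], (0, s2o))
            elif s2o < 0:
                # shifted[x] = img2[x + s2o], zero-padded on the left
                shifted = F.pad(img2[:, :, :, :s2o], (-s2o, 0))
            else:
                shifted = img2
            outs.append((img1 * shifted).sum(dim=1))
        return torch.stack(outs, dim=1) / sumelems


def corr_reference(img1, img2, shifts, kernel_size):
    # Eager-only reference using the same slice-assignment pattern as Corr.
    n, c, h, w = img1.shape
    top = torch.zeros(n, len(shifts), h, w, dtype=img1.dtype)
    for ch, s2o in enumerate(shifts):
        if s2o > 0:
            top[:, ch, :, :-s2o] = (img1[:, :, :, :-s2o] * img2[:, :, :, s2o:]).sum(dim=1)
        elif s2o < 0:
            top[:, ch, :, -s2o:] = (img1[:, :, :, -s2o:] * img2[:, :, :, :s2o]).sum(dim=1)
        else:
            top[:, ch] = (img1 * img2).sum(dim=1)
    return top / (kernel_size * kernel_size * c)


torch.manual_seed(0)
a = torch.randn(2, 5, 6, 16)  # made-up feature maps
b = torch.randn(2, 5, 6, 16)
shifts = [-3, -1, 0, 2, 4]

corr = CorrNoInplace(shifts, kernel_size=1)
traced_corr = torch.jit.trace(corr, (a, b))

expected = corr_reference(a, b, shifts, 1)
same_eager = torch.allclose(corr(a, b), expected, atol=1e-6)
same_traced = torch.allclose(traced_corr(a, b), expected, atol=1e-6)
print(same_eager, same_traced)
```

The shift values are plain Python integers here, so the tracer bakes them into the graph as constants. That is fine as long as the shifts do not depend on the input.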