Hi all, I'm trying to load my PyTorch model in C++. I use
traced_script_module = torch.jit.trace(net, (dammy_input1, dammy_input2))
to convert my model (it takes two inputs). When I run this, I get the following warning:
TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function.
Some values are not within the tolerance of 1e-5. In fact they differ a lot (e.g. -0.242 vs. 3.92).
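One way to reproduce the tracer's check by hand is sketched below. It assumes a toy two-input module; `TwoInputNet` and `compare_eager_and_traced` are hypothetical names, not part of the model in this post. The example inputs are passed as a tuple, as in the call above, and the eager and traced outputs are compared with the same 1e-5 tolerance. Calling `model.eval()` first rules out dropout randomness as a source of mismatch.

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    # Hypothetical stand-in for a model that takes two inputs.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x1, x2):
        return (self.conv1(x1) * self.conv2(x2)).sum(dim=1)


def compare_eager_and_traced(model, inputs, atol=1e-5, rtol=1e-5):
    # Trace with a tuple of example inputs and compare against eager execution.
    model.eval()  # disable dropout so repeated forward passes are deterministic
    with torch.no_grad():
        traced = torch.jit.trace(model, inputs)
        eager_out = model(*inputs)
        traced_out = traced(*inputs)
    max_diff = (eager_out - traced_out).abs().max().item()
    return torch.allclose(eager_out, traced_out, atol=atol, rtol=rtol), max_diff


torch.manual_seed(0)
dummy_input1 = torch.randn(1, 3, 32, 32)
dummy_input2 = torch.randn(1, 3, 32, 32)
ok, max_diff = compare_eager_and_traced(TwoInputNet(), (dummy_input1, dummy_input2))
print(ok, max_diff)
```

Printing `max_diff` alongside the boolean makes it easy to see whether a mismatch is rounding noise or a real difference like the one reported above.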
I rechecked my model. Its structure is roughly:
conv1 -> conv2 -> correlation -> some convs and upsamples ... -> return
I found that the mismatch between the traced function and the Python function is caused by the
correlation module. I checked the output of conv2 just before the correlation, and the traced and Python results agree within the 1e-5 tolerance, so the correlation module must be the problem. The correlation module looks like this:
    class Corr(nn.Module):
        def __init__(self, ...):
            # some parameters init here
            ...

        def corr_func(self, fL, fR):
            # elementwise product, summed over the channel dimension
            return (fL * fR).sum(dim=1)

        def forward(self, img1, img2):
            num, bchannels, bheight, bwidth = img1.shape[0], img1.shape[1], img1.shape[2], img1.shape[3]
            top = torch.zeros(num, 27, 192, 192).type_as(img1.data)
            x_shift = -self.neighborhood_grid_radius
            sumelems = self.kernel_size * self.kernel_size * bchannels
            for ch in range(0, self.top_channels):
                # horizontal displacement for this output channel
                s2o = int((ch % self.neighborhood_grid_width_ + x_shift) * self.stride2) + self.pad_shift
                if s2o > 0:
                    top[:, ch, :, :-s2o] = self.corr_func(img1[:, :, :, :-s2o], img2[:, :, :, s2o:])
                elif s2o < 0:
                    top[:, ch, :, -s2o:] = self.corr_func(img1[:, :, :, -s2o:], img2[:, :, :, :s2o])
                else:
                    top[:, ch, :, :] = self.corr_func(img1, img2)
            top = top / sumelems
            return top
The basic idea of Corr is to take a patch from img1 and the corresponding horizontally shifted patch from img2, multiply them elementwise, and sum the result. I think the correlation is quite similar to a convolution.

So why does torch.jit.trace give such different results for the Python and the traced function, and how should I modify the correlation module to avoid the mismatch?
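The `forward` loop writes each channel into a preallocated `top` tensor through in-place slice assignment. A workaround often suggested for trace mismatches, offered here as an assumption rather than a confirmed diagnosis, is to build the output out of place: shift `img2` with zero padding (`F.pad`) and collect the channels with `torch.stack`. The sketch assumes horizontal-only shifts and a hypothetical `HorizontalCorr(shifts, kernel_size)` constructor, where `shifts` stands in for the per-channel `s2o` values; it also derives the output size from the inputs instead of hard-coding `27, 192, 192`. The hypothetical `corr_inplace` function mirrors the original indexing so the two versions can be compared.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HorizontalCorr(nn.Module):
    """Out-of-place sketch of the horizontal correlation (hypothetical API).

    `shifts` plays the role of the per-channel `s2o` values; `kernel_size`
    only enters the normalisation, as in the original `sumelems`.
    """

    def __init__(self, shifts, kernel_size=1):
        super().__init__()
        self.shifts = [int(s) for s in shifts]  # plain Python ints, fixed at trace time
        self.kernel_size = kernel_size

    def forward(self, img1, img2):
        bchannels = img1.shape[1]
        sumelems = self.kernel_size * self.kernel_size * bchannels
        channels = []
        for s in self.shifts:
            if s > 0:
                # shifted[..., x] = img2[..., x + s], zero-filled on the right
                shifted = F.pad(img2[:, :, :, s:], (0, s))
            elif s < 0:
                # shifted[..., x] = img2[..., x + s], zero-filled on the left
                shifted = F.pad(img2[:, :, :, :s], (-s, 0))
            else:
                shifted = img2
            channels.append((img1 * shifted).sum(dim=1))
        # Build the output with torch.stack instead of writing into a preallocated tensor.
        return torch.stack(channels, dim=1) / sumelems


def corr_inplace(img1, img2, shifts, kernel_size=1):
    # Same indexing as the original forward(), written as a plain function.
    num, bchannels, bheight, bwidth = img1.shape
    top = torch.zeros(num, len(shifts), bheight, bwidth, dtype=img1.dtype)
    for ch, s2o in enumerate(shifts):
        if s2o > 0:
            top[:, ch, :, :-s2o] = (img1[:, :, :, :-s2o] * img2[:, :, :, s2o:]).sum(dim=1)
        elif s2o < 0:
            top[:, ch, :, -s2o:] = (img1[:, :, :, -s2o:] * img2[:, :, :, :s2o]).sum(dim=1)
        else:
            top[:, ch] = (img1 * img2).sum(dim=1)
    return top / (kernel_size * kernel_size * bchannels)


torch.manual_seed(0)
a, b = torch.randn(2, 4, 8, 16), torch.randn(2, 4, 8, 16)
shifts = list(range(-4, 5, 2))  # hypothetical displacements
corr = HorizontalCorr(shifts).eval()
traced = torch.jit.trace(corr, (a, b))
print(torch.allclose(corr(a, b), corr_inplace(a, b, shifts), atol=1e-6))
print(torch.allclose(corr(a, b), traced(a, b), atol=1e-6))
```

Because the shifts are plain Python ints, the tracer unrolls the loop and records each slice with constant bounds, so only tensor operations end up in the traced graph.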
Any suggestions would be appreciated. Thanks a lot!