Can't trace the model using torch.jit.trace

I can’t trace my model using torch.jit.trace. It is a ResNet-101-based segmentation model.
I am using Python 3.7, PyTorch 1.8 and an RTX 3070 (8 GB).
My code:

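import torch

Net = FCN.Net(CatDic.CatNum)
# strict=False: missing or unexpected keys are ignored instead of raising an error
Net.load_state_dict(torch.load('./model.torch', map_location=torch.device('cuda')), strict=False)
Net.eval()

# Example input: one 640x640 RGB image in channels-last (N, H, W, C) layout
c = torch.jit.trace(Net, torch.randn(1, 640, 640, 3).cuda())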

My neural network structure:

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models

class Net(nn.Module):
    def __init__(self, CatDict):
            super(Net, self).__init__()
            self.Encoder = models.resnet101(pretrained=True)
            self.PSPScales = [1, 1 / 2, 1 / 4, 1 / 8]

            self.PSPLayers = nn.ModuleList()
            for Ps in self.PSPScales:
                self.PSPLayers.append(nn.Sequential(
                    nn.Conv2d(2048, 1024, stride=1, kernel_size=3, padding=1, bias=True)))
            self.PSPSqueeze = nn.Sequential(
                nn.Conv2d(4096, 512, stride=1, kernel_size=1, padding=0, bias=False),
                nn.BatchNorm2d(512),
                nn.ReLU(),
                nn.Conv2d(512, 512, stride=1, kernel_size=3, padding=0, bias=False),
                nn.BatchNorm2d(512),
                nn.ReLU()
            )
            
            self.SkipConnections = nn.ModuleList()
            self.SkipConnections.append(nn.Sequential(
                nn.Conv2d(1024, 512, stride=1, kernel_size=1, padding=0, bias=False),
                nn.BatchNorm2d(512),
                nn.ReLU()))
            self.SkipConnections.append(nn.Sequential(
                nn.Conv2d(512, 256, stride=1, kernel_size=1, padding=0, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU()))
            self.SkipConnections.append(nn.Sequential(
                nn.Conv2d(256, 256, stride=1, kernel_size=1, padding=0, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU()))
            # ------------------Skip squeeze applied to the (concat of upsample+skip connection layers)-----------------------------------------------------------------------------
            self.SqueezeUpsample = nn.ModuleList()
            self.SqueezeUpsample.append(nn.Sequential(
                nn.Conv2d(1024, 512, stride=1, kernel_size=1, padding=0, bias=False),
                nn.BatchNorm2d(512),
                nn.ReLU()))
            self.SqueezeUpsample.append(nn.Sequential(
                nn.Conv2d(256 + 512, 256, stride=1, kernel_size=1, padding=0, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU()))
            self.SqueezeUpsample.append(nn.Sequential(
                nn.Conv2d(256 + 256, 256, stride=1, kernel_size=1, padding=0, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU()))


            self.OutLayersList =nn.ModuleList()
            self.OutLayersDict={}
            for f,nm in enumerate(CatDict):
                    self.OutLayersDict[nm]= nn.Conv2d(256, 2, stride=1, kernel_size=3, padding=1, bias=False)
                    self.OutLayersList.append(self.OutLayersDict[nm])

    def forward(self,Images, UseGPU = True, TrainMode=False, FreezeBatchNormStatistics=False):
                RGBMean = [123.68,116.779,103.939]
                RGBStd = [65,65,65]
                if TrainMode:
                        tp=torch.FloatTensor
                else:
                    self.half()
                    tp=torch.HalfTensor
                    self.eval()
                #InpImages = torch.autograd.Variable(torch.from_numpy(Images), requires_grad=False).transpose(2,3).transpose(1, 2).type(torch.FloatTensor)

                InpImages = torch.autograd.Variable(Images, requires_grad=False).transpose(2,3).transpose(1, 2).type(tp)
                if FreezeBatchNormStatistics==True: self.eval()
                if UseGPU:
                    InpImages=InpImages.cuda()
                    self.cuda()
                else:
                    self=self.cpu()
                    self.float()
                    InpImages=InpImages.type(torch.float).cpu()
                for i in range(len(RGBMean)): InpImages[:, i, :, :]=(InpImages[:, i, :, :]-RGBMean[i])/RGBStd[i] # normalize image values
                x=InpImages
                SkipConFeatures=[] # Store features map of layers used for skip connection
                x = self.Encoder.conv1(x)
                x = self.Encoder.bn1(x)
                x = self.Encoder.relu(x)
                x = self.Encoder.maxpool(x)
                x = self.Encoder.layer1(x)
                SkipConFeatures.append(x)
                x = self.Encoder.layer2(x)
                SkipConFeatures.append(x)
                x = self.Encoder.layer3(x)
                SkipConFeatures.append(x)
                x = self.Encoder.layer4(x)
                PSPSize=(x.shape[2],x.shape[3]) # Size of the original features map

                PSPFeatures=[] # Results of processing at the various scales
                for i,PSPLayer in enumerate(self.PSPLayers): # run PSP layers: scale the feature map to various sizes, apply convolution and concat the results
                      NewSize=(np.array(PSPSize)*self.PSPScales[i]).astype(np.int)
                      if NewSize[0] < 1: NewSize[0] = 1
                      if NewSize[1] < 1: NewSize[1] = 1

                      y = nn.functional.interpolate(x, tuple(NewSize), mode='bilinear')
                      y = PSPLayer(y)
                      y = nn.functional.interpolate(y, PSPSize, mode='bilinear')
                      PSPFeatures.append(y)
                x=torch.cat(PSPFeatures,dim=1)
                x=self.PSPSqueeze(x)
                for i in range(len(self.SkipConnections)):
                  sp=(SkipConFeatures[-1-i].shape[2],SkipConFeatures[-1-i].shape[3])
                  x=nn.functional.interpolate(x,size=sp,mode='bilinear') #Resize
                  x = torch.cat((self.SkipConnections[i](SkipConFeatures[-1-i]),x), dim=1)
                  x = self.SqueezeUpsample[i](x)

                self.OutLbDict = {}
                
                ret_arr = np.eye(640, 640)
                for nm in self.OutLayersDict:
                  l=self.OutLayersDict[nm](x)
                  l = nn.functional.interpolate(l,size=InpImages.shape[2:4],mode='bilinear') # Resize to original image size
                  tt, Labels = l.max(1)  # Find label per pixel
                  self.OutLbDict[nm] = Labels
                  array = np.asarray(self.OutLbDict[nm].cpu())
                  resx = np.reshape(array, ((array.shape)[2], (array.shape)[1]))
                  ret_arr = list(ret_arr + resx*10)

                return ret_arr

I get an error:

RuntimeError: Tracer cannot infer type of [array([..])]
:Could not infer type of list element: Only tensors and (possibly nested) tuples of tensors, lists, or dictsare supported as inputs or outputs of traced functions, but instead got value of type ndarray.
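A minimal reproduction of this first error, as a sketch: the function names and the tiny (1, 2, 4, 4) input are made up for illustration. Converting to numpy and returning a list of arrays (as the end of the question's forward does with ret_arr) makes torch.jit.trace fail with a RuntimeError, while keeping the same computation in torch ops traces fine. The exact message text may differ between PyTorch versions.

```python
import numpy as np
import torch


def labels_as_numpy(x):
    # Mirrors the end of the question's forward: tensor -> numpy -> Python list of arrays.
    arr = x.argmax(dim=1).numpy()
    return list(arr * 10)  # a list of np.ndarray cannot be a traced output


def labels_as_tensor(x):
    # The same computation kept in torch, so the tracer can record every op.
    return x.argmax(dim=1) * 10


example = torch.randn(1, 2, 4, 4)

numpy_trace_failed = False
try:
    torch.jit.trace(labels_as_numpy, example)
except RuntimeError as err:
    numpy_trace_failed = True
    print("numpy version:", str(err).splitlines()[0])

traced = torch.jit.trace(labels_as_tensor, example)
print(traced(example).shape)  # torch.Size([1, 4, 4])
```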

If I remove all numpy arrays from the code, then I get a different error:

C:\anaconda3\lib\site-packages\torch\jit\_trace.py in _check_trace(check_inputs, func, traced_func, check_tolerance, strict, force_outplace, is_trace_module, _module_class)
517 diag_info = graph_diagnostic_info()
518 if any(info is not None for info in diag_info):
→ 519 raise TracingCheckError(*diag_info)
520
521

TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
	Graph diff:
		  graph(%self.1 : __torch__.FCN_NetModel.Net,
		        %Images : Tensor):
		    %2 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="SqueezeUpsample"](%self.1)
		    %3 : __torch__.torch.nn.modules.container.Sequential = prim::GetAttr[name="2"](%2)
		    %4 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="SkipConnections"](%self.1)
		    %5 : __torch__.torch.nn.modules.container.Sequential = prim::GetAttr[name="2"](%4)

........

Interestingly, if I run c = torch.jit.trace(Net, torch.randn(1, 640, 640, 3).cuda()) a second time, this error does not occur and tracing succeeds. However, the resulting traced model doesn’t work. I would be grateful for your help.
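The thread doesn’t settle what causes the graph diff. One plausible contributor, stated here as an assumption rather than a confirmed diagnosis, is that forward mutates the module itself (self.half(), self.cuda(), self.eval()), so the first call sees a different module than every later call; that would also fit the observation that a second trace succeeds. The sketch below uses self.double() on CPU as a stand-in for self.half() on the GPU, and both class names are hypothetical.

```python
import torch
import torch.nn as nn


class MutatesInForward(nn.Module):
    """Hypothetical module that changes its own dtype inside forward, like the question's Net."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def forward(self, x):
        self.double()  # CPU stand-in for self.half()/self.cuda(): mutates the parameters on every call
        return self.conv(x.double())


class ConvertsOnce(nn.Module):
    """Same layer, but dtype/device/mode are chosen once by the caller, not inside forward."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x.to(self.conv.weight.dtype))


m = MutatesInForward()
print(m.conv.weight.dtype)  # torch.float32
m(torch.randn(1, 3, 8, 8))
print(m.conv.weight.dtype)  # torch.float64: the next call (or trace) sees a different module

fixed = ConvertsOnce().double().eval()  # conversions done once, before tracing
example = torch.randn(1, 3, 8, 8, dtype=torch.float64)
traced = torch.jit.trace(fixed, example)
```

The design choice in ConvertsOnce is to let the caller pick dtype, device and eval mode before tracing, so forward contains only tensor operations.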

As you’ve already described, the first error is raised if numpy arrays are used instead of tensors.
The second one is raised if your forward pass is data-dependent and could change for different inputs.
Tracing records only the operations executed for the provided example input: Python conditions are frozen to the branch taken during tracing, and loops are unrolled. If you want to keep conditions, loops, etc. (as is suggested by the error message), you could use torch.jit.script on the model instead.
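A small illustration of the trace-versus-script difference, using a hypothetical function double_or_negate with a data-dependent if: tracing with a positive example input records only the x * 2 branch (PyTorch emits a TracerWarning, silenced here), while torch.jit.script compiles both branches. Note that torch.jit.script needs access to the function's source, so this is meant to be run from a .py file.

```python
import warnings

import torch


def double_or_negate(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent condition: which branch runs depends on the values in x.
    if x.sum() > 0:
        return x * 2
    return -x


pos = torch.ones(3)
neg = -torch.ones(3)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide the TracerWarning about converting a tensor to a bool
    traced = torch.jit.trace(double_or_negate, pos)  # records only the `x * 2` branch
scripted = torch.jit.script(double_or_negate)  # compiles both branches

print(traced(neg))    # tensor([-2., -2., -2.])  (branch frozen during tracing)
print(scripted(neg))  # tensor([1., 1., 1.])     (same as calling the Python function)
```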


Thank you very much. I roughly understand the problem, but I still have some doubts. Could you give me an example for the first and second errors, even a rough one, preferably based on the code in my question?
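As a rough illustration of the kind of example asked for here, the sketch below restructures the question's forward so it stays traceable: normalization uses tensor broadcasting with registered buffers, the per-category heads live in an nn.ModuleDict, no dtype or device changes happen inside forward, and the result is returned as a tensor. It is not the original model: TraceFriendlySegNet is a hypothetical name, a single small convolution stands in for the ResNet-101 encoder and the PSP/skip layers, the category names are placeholders, and it returns a stacked label tensor rather than reproducing the np.eye accumulation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TraceFriendlySegNet(nn.Module):
    """Hypothetical stand-in for the question's Net (tiny encoder, no ResNet)."""

    def __init__(self, cat_names):
        super().__init__()
        # Normalization constants as buffers, so .to()/.half() move them with the model.
        self.register_buffer("rgb_mean", torch.tensor([123.68, 116.779, 103.939]).view(1, 3, 1, 1))
        self.register_buffer("rgb_std", torch.tensor([65.0, 65.0, 65.0]).view(1, 3, 1, 1))
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        # nn.ModuleDict instead of a plain dict: the heads are registered submodules.
        self.out_layers = nn.ModuleDict(
            {name: nn.Conv2d(16, 2, kernel_size=3, padding=1) for name in cat_names}
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # images: (N, H, W, 3), same layout as the question's example input.
        x = images.permute(0, 3, 1, 2)  # -> (N, 3, H, W); no dtype/device changes here
        x = (x - self.rgb_mean) / self.rgb_std  # per-channel normalization by broadcasting
        feats = self.encoder(x)
        labels = []
        for head in self.out_layers.values():
            logits = F.interpolate(head(feats), size=x.shape[2:], mode="bilinear", align_corners=False)
            labels.append(logits.argmax(dim=1))  # (N, H, W) label map for this category
        return torch.stack(labels, dim=1)  # a tensor output instead of numpy arrays


model = TraceFriendlySegNet(["cat_a", "cat_b"]).eval()  # mode/dtype/device set once, outside forward
example = torch.randn(1, 64, 64, 3)
with torch.no_grad():
    traced = torch.jit.trace(model, example)
print(traced(example).shape)  # torch.Size([1, 2, 64, 64])
```

For the question’s half-precision GPU setup, the analogous step would presumably be calling .half().cuda().eval() on the model once before tracing and passing an example input with the same dtype and device; that variant is not tested here.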

Based on the graph diff in the error message, the issue seems to be that one invocation of your module by the tracer calls self.SqueezeUpsample[2] and self.SkipConnections[2] but the next one does not. However, I cannot pinpoint where in your code this might be happening. self.SqueezeUpsample and self.SkipConnections are used in for loops, but those loops have a deterministic number of iterations…
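For reference, a fixed-length loop over an nn.ModuleList is not a problem for the tracer by itself: it is simply unrolled, and re-tracing with a differently sized check input should produce the same graph. A sketch with a hypothetical three-stage module, loosely mirroring SkipConnections and SqueezeUpsample:

```python
import torch
import torch.nn as nn


class FixedLoop(nn.Module):
    """Hypothetical module looping over a fixed-length ModuleList."""

    def __init__(self):
        super().__init__()
        self.stages = nn.ModuleList(
            [nn.Conv2d(8, 8, kernel_size=3, padding=1) for _ in range(3)]
        )

    def forward(self, x):
        # The number of iterations never depends on x, so the tracer just unrolls it.
        for stage in self.stages:
            x = torch.relu(stage(x))
        return x


model = FixedLoop().eval()
traced = torch.jit.trace(
    model,
    torch.randn(1, 8, 16, 16),
    check_inputs=[(torch.randn(1, 8, 32, 32),)],  # re-trace with another size and compare graphs
)
print(traced.code)  # forward with the three stage calls written out one after another
```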

Your information is very interesting. Thank you for sharing.