RuntimeError: ENTER is not supported in mobile module

I am trying to build the .ptl weights for a simple SRE model. I have already simplified the model so that it doesn’t contain any complex operations.

But it’s still throwing this ambiguous error:

RecursiveScriptModule(original_name=MainModel)
Traceback (most recent call last):
  File "melspec_test.py", line 90, in <module>
    optimized_model._save_for_lite_interpreter("vgg_sre.ptl")
  File "C:\ProgramData\Anaconda3\envs\py38\lib\site-packages\torch\jit\_script.py", line 707, in _save_for_lite_interpreter
    return self._c._save_for_mobile(*args, **kwargs)
RuntimeError: ENTER is not supported in mobile module.

My code:

import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.utils.mobile_optimizer import optimize_for_mobile


class MainModel(nn.Module):
    def __init__(self, nOut = 1024):
        super(MainModel, self).__init__()
        
        self.netcnn = nn.Sequential(
            nn.Conv2d(1, 96, kernel_size=(5,7), stride=(1,2), padding=(2,2)),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(1,3), stride=(1,2)),

            nn.Conv2d(96, 256, kernel_size=(5,5), stride=(2,2), padding=(1,1)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)),

            nn.Conv2d(256, 384, kernel_size=(3,3), padding=(1,1)),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=(3,3), padding=(1,1)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=(3,3), padding=(1,1)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)),

            nn.Conv2d(256, 512, kernel_size=(4,1), padding=(0,0)),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            
        )

        self.sap_linear = nn.Linear(512, 512)
        self.attention = self.new_parameter(512, 1)
        out_dim = 512


        self.fc = nn.Linear(out_dim, nOut)

    def new_parameter(self, *size):
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out
        
    def forward(self, x):

        with torch.no_grad():
            x = x.log()
            x = x.unsqueeze(1)

        x = self.netcnn(x)

        x = x.permute(0, 2, 1, 3)
        x = x.squeeze(dim=1).permute(0, 2, 1)  # batch * L * D
        h = torch.tanh(self.sap_linear(x))
        w = torch.matmul(h, self.attention).squeeze(dim=2)
        w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1)
        x = torch.sum(x * w, dim=1)

        x = self.fc(x)

        return x

model = MainModel()
model.eval()
x = torch.rand(1, 40, 200)
traced_cell = torch.jit.trace(model, x)


quantized_model = torch.quantization.quantize_dynamic(
    model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_model = torch.jit.script(quantized_model)
optimized_model = optimize_for_mobile(scripted_model)

print(optimized_model)

optimized_model._save_for_lite_interpreter("vgg_sre.ptl")

I am using PyTorch torch-1.11.0.dev20220108+cpu (nightly) on Windows.

I am not sure which operation it’s failing for; there is no clear error message.
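
Is there a way to narrow it down? One idea (just a sketch, assuming the scripted_model created above, and that a with block is what produces the ENTER instruction) would be to scan the TorchScript graph for prim::Enter nodes:

for node in scripted_model.graph.nodes():
    # a `with` block is lowered to prim::Enter / prim::Exit during scripting;
    # printing a matching node should also show the Python source line it came from
    if node.kind() == "prim::Enter":
        print(node)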

It says the “ENTER” instruction is not supported in mobile.

I don’t know which op triggers it (maybe the with torch.no_grad() block in forward?). cc @Martin_Yuan

I think the with torch.no_grad(): block in forward might be the issue. The ENTER instruction is what enters a runtime context, and it is not supported by the mobile (lite) interpreter yet. The with statement creates a runtime context that lets you run a group of statements under the control of a context manager, and it is that context handling the lite interpreter cannot execute.
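
If that is the case, a possible workaround (just a sketch, not tested on your exact model) is to remove the with torch.no_grad(): block from forward() and, if you still need it, disable gradients at the call site instead; no_grad does not change what forward computes, only whether gradients are tracked, so the outputs stay the same. A minimal toy example of the difference (hypothetical modules, not the SRE model from the question):

import torch
import torch.nn as nn

class WithContext(nn.Module):
    def forward(self, x):
        # the context manager is scripted into prim::Enter / prim::Exit
        with torch.no_grad():
            return x.log()

class WithoutContext(nn.Module):
    def forward(self, x):
        # same computation, no context manager, so no ENTER instruction
        return x.log()

# this should reproduce the same "ENTER is not supported" error on your version:
# torch.jit.script(WithContext())._save_for_lite_interpreter("with_ctx.ptl")

# this exports fine; wrap the *call* in torch.no_grad() from the host code
# if you still want gradients disabled during Python-side inference:
torch.jit.script(WithoutContext())._save_for_lite_interpreter("no_ctx.ptl")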
