torch._dynamo.exc.Unsupported: call_method NNModuleVariable() _sa_block [TensorVariable(), LazyVariableTracker(), LazyVariableTracker()] {}

I’m encountering an issue when converting my PyTorch model to TensorFlow Lite with the ai-edge-torch library. The error is raised when _sa_block is called from torch.nn.TransformerEncoderLayer, and it appears to come from torch.export.export, which ai-edge-torch uses internally. Below is the portion of my model that causes the issue, followed by how I invoke the conversion.

import torch
import torch.nn as nn
import torch.nn.functional as F
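
# NOTE: MyGroupNorm and LayerScale are referenced below but were not included
# in this snippet. These are minimal stand-ins (assumptions, not my exact code)
# so the example is self-contained; my real implementations may differ.
class MyGroupNorm(nn.GroupNorm):
    def forward(self, x):
        # GroupNorm expects channels in dim 1, so move the channel dim,
        # normalize, and move it back.
        x = x.transpose(1, 2)
        x = super().forward(x)
        return x.transpose(1, 2)


class LayerScale(nn.Module):
    """Per-channel learnable scaling."""

    def __init__(self, channels, init=1e-4, channel_last=False):
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(torch.full((channels,), float(init)))

    def forward(self, x):
        if self.channel_last:
            return self.scale * x
        return self.scale[:, None] * x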

class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        group_norm=0,
        norm_first=False,
        norm_out=False,
        layer_norm_eps=1e-5,
        layer_scale=False,
        init_values=1e-4,
        device=None,
        dtype=None,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        auto_sparsity=False,
        sparsity=0.95,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            norm_first=norm_first,
            device=device,
            dtype=dtype,
        )
       
        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        if self.norm_first & norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
        self.gamma_1 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )
        self.gamma_2 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        if batch_first = False, src shape is (T, B, C)
        the case where batch_first=True is not covered
        """
        device = src.device
        x = src
        T, B, C = x.shape
        if self.norm_first:
            x = self.norm1(x)
            x = x + self.gamma_1(
                self._sa_block(x, src_mask, src_key_padding_mask)
            )
            x = x + self.gamma_2(self._ff_block(self.norm2(x)))

            if self.norm_out:
                x = self.norm_out(x)
        else:
            x = self.norm1(
                x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
            )
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x
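
For context, the conversion is invoked roughly like this (a minimal sketch: the layer dimensions, input shape, and output path are placeholders, not my actual configuration):

import ai_edge_torch

layer = MyTransformerEncoderLayer(d_model=512, nhead=8).eval()
sample_src = torch.randn(100, 1, 512)  # (T, B, C) since batch_first=False

# ai-edge-torch traces the model with torch.export under the hood;
# this call is where the Unsupported error is raised.
edge_model = ai_edge_torch.convert(layer, (sample_src,))
edge_model.export("my_transformer_layer.tflite")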

When _sa_block is called from torch.nn.TransformerEncoderLayer, I get the following error:
torch._dynamo.exc.Unsupported: call_method NNModuleVariable() _sa_block [TensorVariable(), LazyVariableTracker(), LazyVariableTracker()] {}

The error occurs specifically at the _sa_block call within the forward method.

I believe it is related to torch.export.export, which ai-edge-torch uses under the hood.
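
To check this independently of ai-edge-torch, the layer can be exported directly with torch.export.export (a minimal sketch; the dimensions and input shape are placeholders):

layer = MyTransformerEncoderLayer(d_model=512, nhead=8).eval()
sample_src = torch.randn(100, 1, 512)  # (T, B, C)

# If the failure really comes from torch.export, the same Unsupported
# exception should be raised here, without ai-edge-torch involved.
exported = torch.export.export(layer, (sample_src,))
print(exported)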

Thank you for your attention to this matter. I look forward to your response and any guidance you can provide.