TorchScript error: "Unsupported value kind: Tensor"

Hello everyone! I recently started using TorchScript and have run into an error I don't know how to solve. I am trying to script a model; for reference, here is its constructor:

def __init__(
        self,
        dim: int,
        emb: str = "sin",
        hidden_scale: float = 4.0,
        num_heads: int = 8,
        num_layers: int = 6,
        cross_first: bool = False,
        dropout: float = 0.0,
        max_positions: int = 1000,
        norm_in: bool = True,
        norm_in_group: bool = False,
        group_norm: int = False,
        norm_first: bool = False,
        norm_out: bool = False,
        max_period: float = 10000.0,
        weight_decay: float = 0.0,
        lr: tp.Optional[float] = None,
        layer_scale: bool = False,
        gelu: bool = True,
        sin_random_shift: int = 0,
        weight_pos_embed: float = 1.0,
        cape_mean_normalize: bool = True,
        cape_augment: bool = True,
        cape_glob_loc_scale: tp.Optional[tp.List[float]] = None,
        sparse_self_attn: bool = False,
        sparse_cross_attn: bool = False,
        mask_type: str = "diag",
        mask_random_seed: int = 42,
        sparse_attn_window: int = 500,
        global_window: int = 50,
        auto_sparsity: bool = False,
        sparsity: float = 0.95,
    ):
        """Build two parallel stacks (``layers`` and ``layers_t``) of encoder layers.

        Within each stack, layers alternate between a "classic" layer
        (``MyTransformerEncoderLayer``) and a cross layer
        (``CrossTransformerEncoderLayer``). Layer ``idx`` is classic when
        ``idx % 2 == self.classic_parity``.

        Args:
            dim: Model width, passed as ``d_model`` to every layer. Must be
                divisible by ``num_heads``.
            emb: Positional-embedding kind.
                - ``"cape"`` stores the ``cape_*`` options.
                - ``"scaled"`` builds a ``ScaledEmbedding(max_positions, dim, scale=0.2)``.
                - Any other value leaves ``position_embeddings`` as ``nn.Identity()``.
            hidden_scale: Feed-forward width is ``int(dim * hidden_scale)``.
            num_heads: Attention heads, passed as ``nhead``.
            num_layers: Number of layers in each of the two stacks.
            cross_first: If True, layer 0 of each stack is a cross layer.
            dropout, group_norm, norm_first, norm_out, layer_scale, mask_type,
            mask_random_seed, sparse_attn_window, global_window, sparsity,
            auto_sparsity: Forwarded unchanged to every layer.
            max_positions: Only used when ``emb == "scaled"``.
            norm_in: Use ``nn.LayerNorm(dim)`` as the input norm of both streams.
            norm_in_group: If ``norm_in`` is False and this is True, use
                ``MyGroupNorm(int(norm_in_group), dim)``, i.e. one group.
            max_period, weight_decay, sin_random_shift, weight_pos_embed, lr:
                Stored on the module as attributes.
            gelu: Pass ``F.gelu`` as the layer activation if True, else ``F.relu``.
            cape_mean_normalize, cape_augment: Stored only when ``emb == "cape"``.
            cape_glob_loc_scale: Stored (as a copy) only when ``emb == "cape"``.
                ``None`` means ``[5000.0, 1.0, 1.4]``.
            sparse_self_attn: ``sparse`` flag for the classic layers.
            sparse_cross_attn: ``sparse`` flag for the cross layers.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        hidden_dim = int(dim * hidden_scale)

        self.num_layers = num_layers
        # classic parity = 1 means that if idx%2 == 1 there is a
        # classical encoder else there is a cross encoder
        self.classic_parity = 1 if cross_first else 0
        self.emb = emb
        self.max_period: float = max_period
        self.weight_decay = weight_decay
        self.weight_pos_embed = weight_pos_embed
        self.sin_random_shift: int = sin_random_shift
        # Always define these attributes with default values, so the module
        # has the same attribute set whatever `emb` is.
        self.cape_mean_normalize: bool = False
        self.cape_augment: bool = False
        self.cape_glob_loc_scale: list = [5000.0, 1.0, 1.4]
        self.position_embeddings: nn.Module = nn.Identity()  # Default value

        if emb == "cape":
            self.cape_mean_normalize = cape_mean_normalize
            self.cape_augment = cape_augment
            # Copy so that instances never share (and mutate) the caller's list.
            if cape_glob_loc_scale is not None:
                self.cape_glob_loc_scale = list(cape_glob_loc_scale)
        if emb == "scaled":
            self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)

        self.lr = lr

        # NOTE(review): a plain Python function is passed to the layers here.
        # Whether TorchScript can compile a layer that stores it depends on the
        # layer class (not shown) — confirm.
        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        self.norm_in_t: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
            self.norm_in_t = nn.LayerNorm(dim)
        elif norm_in_group:
            self.norm_in = MyGroupNorm(int(norm_in_group), dim)
            self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
        else:
            self.norm_in = nn.Identity()
            self.norm_in_t = nn.Identity()

        # spectrogram layers
        self.layers = nn.ModuleList()
        # temporal layers
        self.layers_t: nn.ModuleList = nn.ModuleList()

        kwargs_common = {
            "d_model": dim,
            "nhead": num_heads,
            "dim_feedforward": hidden_dim,
            "dropout": dropout,
            "activation": activation,
            "group_norm": group_norm,
            "norm_first": norm_first,
            "norm_out": norm_out,
            "layer_scale": layer_scale,
            "mask_type": mask_type,
            "mask_random_seed": mask_random_seed,
            "sparse_attn_window": sparse_attn_window,
            "global_window": global_window,
            "sparsity": sparsity,
            "auto_sparsity": auto_sparsity,
            "batch_first": True,
        }

        # The two layer kinds differ only in their `sparse` flag.
        kwargs_classic_encoder = dict(kwargs_common, sparse=sparse_self_attn)
        kwargs_cross_encoder = dict(kwargs_common, sparse=sparse_cross_attn)

        for idx in range(num_layers):
            if idx % 2 == self.classic_parity:
                self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
                self.layers_t.append(
                    MyTransformerEncoderLayer(**kwargs_classic_encoder)
                )
            else:
                self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
                self.layers_t.append(
                    CrossTransformerEncoderLayer(**kwargs_cross_encoder)
                )

And this is my forward method:

def forward(self, x, xt):
        """Run both streams through the alternating self-/cross-attention stacks.

        Args:
            x: 4-D tensor unpacked as ``(B, C, Fr, T1)``; presumably
                (batch, channels, freq, time) — confirm.
            xt: 3-D tensor unpacked as ``(B, C, T2)``.

        Returns:
            ``(x, xt)`` after the layer stacks, each rearranged by the
            ``rearrange_*`` helpers. Going by the helper names, presumably back
            to ``(B, C, Fr, T1)`` and ``(B, C, T2)`` — confirm.
        """
        # For x:
        B, C, Fr, T1 = x.shape
        pos_emb_2d = create_2d_sin_embedding(C, Fr, T1, x.device, self.max_period)  # (1, C, Fr, T1)
        pos_emb_2d = rearrange_b_cfrt1_to_b_t1fr_c(pos_emb_2d)
        x = rearrange_b_cfrt1_to_b_t1fr_c(x)
        x = self.norm_in(x)
        x = x + self.weight_pos_embed * pos_emb_2d

        # For xt:
        B, C, T2 = xt.shape
        xt = rearrange_b_c_t2_to_b_t2_c(xt)
        pos_emb = self._get_pos_embedding(int(T2), int(B), int(C), x.device)
        # Assume pos_emb shape is (T2, B, C); rearrange to (B, T2, C)
        pos_emb = pos_emb.permute(1, 0, 2)
        xt = self.norm_in_t(xt)
        xt = xt + self.weight_pos_embed * pos_emb

        # Processing layers:
        # Layer kinds alternate by index parity (see classic_parity in __init__).
        # NOTE(review): TorchScript rejects `self.layers_t[index]` because a
        # ModuleList may only be indexed with an integer literal. Iterating
        # both lists together, e.g.
        #   for index, (layer, layer_t) in enumerate(zip(self.layers, self.layers_t)):
        # avoids the indexing — confirm zip support on your PyTorch version.
        # NOTE(review): per the "integer literal" error, `index` is not a
        # compile-time constant. Both branches are therefore presumably compiled
        # for every unrolled layer: `layer(x)` gets checked against cross layers
        # and `layer(x, xt)` against classic layers. That is a likely cause of
        # the type error — confirm.
        for index, layer in enumerate(self.layers):
            if index % 2 == self.classic_parity:
                x = layer(x)
                xt = self.layers_t[index](xt)
            else:
                # Keep the pre-update x so the xt cross layer attends to the
                # same x that the x cross layer received.
                old_x = x
                x = layer(x, xt)
                xt = self.layers_t[index](xt, old_x)

        x = rearrange_b_t1fr_c_to_b_c_fr_t1(x, T1)
        xt = rearrange_b_t2_c_to_b_c_t2(xt)
        return x, xt

I know the problem is in this for loop: I commented out various parts to find the cause, and the error comes specifically from the layer calls such as `layer(x)` and `layer(x, xt)`:

# Excerpt of the layer loop in forward(): layers alternate between classic
# (single-input) and cross (two-input) calls by index parity.
# NOTE(review): `self.layers_t[index]` indexes a ModuleList with a non-literal
# index, which TorchScript rejects (see the second error quoted below).
for index, layer in enumerate(self.layers):
            if index % 2 == self.classic_parity:
                x = layer(x)
                xt = self.layers_t[index](xt)
            else:
                old_x = x
                x = layer(x, xt)
                xt = self.layers_t[index](xt, old_x)

What is wrong here? Could you help me find a way to solve this? I have no idea how to approach it.

Another thing: if I comment out the lines `x = layer(x)` and `x = layer(x, xt)`, I get a different error:
Expected integer literal for index but got a variable or non-integer. ModuleList/Sequential indexing is only supported with integer literals. For example, 'i = 4; self.layers[i](x)' will fail because i is not a literal. Enumeration is supported, e.g. 'for index, v in enumerate(self): out = v(inp)':

        if index % 2 == self.classic_parity:
            # x = layer(x)
            xt = self.layers_t[index](xt)
                 ~~~~~~~~~~~~~~~~~~~~ <--- HERE
        else:
            old_x = x

If you need more information about the other classes, for example `MyTransformerEncoderLayer` or `CrossTransformerEncoderLayer`, I'll send them! Your help will be much appreciated!