Hello everyone! I have recently started using TorchScript, and I am running into an error that I don't know how to solve. I am trying to script a model (the scripting call itself is shown right after the forward method below). This is the model's constructor, so you have some reference:
def __init__(
    self,
    dim: int,
    emb: str = "sin",
    hidden_scale: float = 4.0,
    num_heads: int = 8,
    num_layers: int = 6,
    cross_first: bool = False,
    dropout: float = 0.0,
    max_positions: int = 1000,
    norm_in: bool = True,
    norm_in_group: bool = False,
    group_norm: int = False,
    norm_first: bool = False,
    norm_out: bool = False,
    max_period: float = 10000.0,
    weight_decay: float = 0.0,
    lr: tp.Optional[float] = None,
    layer_scale: bool = False,
    gelu: bool = True,
    sin_random_shift: int = 0,
    weight_pos_embed: float = 1.0,
    cape_mean_normalize: bool = True,
    cape_augment: bool = True,
    cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
    sparse_self_attn: bool = False,
    sparse_cross_attn: bool = False,
    mask_type: str = "diag",
    mask_random_seed: int = 42,
    sparse_attn_window: int = 500,
    global_window: int = 50,
    auto_sparsity: bool = False,
    sparsity: float = 0.95,
):
    super().__init__()
    assert dim % num_heads == 0
    hidden_dim = int(dim * hidden_scale)
    self.num_layers = num_layers
    # classic_parity == 1 means that layers with idx % 2 == 1 are classical
    # encoders, the other layers are cross encoders
    self.classic_parity = 1 if cross_first else 0
    self.emb = emb
    self.max_period: float = max_period
    self.weight_decay = weight_decay
    self.weight_pos_embed = weight_pos_embed
    self.sin_random_shift: int = sin_random_shift
    # Always define these attributes with default values
    self.cape_mean_normalize: bool = False
    self.cape_augment: bool = False
    self.cape_glob_loc_scale: list = [5000.0, 1.0, 1.4]
    self.position_embeddings: nn.Module = nn.Identity()  # default value
    if emb == "cape":
        self.cape_mean_normalize = cape_mean_normalize
        self.cape_augment = cape_augment
        self.cape_glob_loc_scale = cape_glob_loc_scale
    if emb == "scaled":
        self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
    self.lr = lr
    activation: tp.Any = F.gelu if gelu else F.relu
    self.norm_in: nn.Module
    self.norm_in_t: nn.Module
    if norm_in:
        self.norm_in = nn.LayerNorm(dim)
        self.norm_in_t = nn.LayerNorm(dim)
    elif norm_in_group:
        self.norm_in = MyGroupNorm(int(norm_in_group), dim)
        self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
    else:
        self.norm_in = nn.Identity()
        self.norm_in_t = nn.Identity()
    # spectrogram layers
    self.layers = nn.ModuleList()
    # temporal layers
    self.layers_t: nn.ModuleList = nn.ModuleList()
    kwargs_common = {
        "d_model": dim,
        "nhead": num_heads,
        "dim_feedforward": hidden_dim,
        "dropout": dropout,
        "activation": activation,
        "group_norm": group_norm,
        "norm_first": norm_first,
        "norm_out": norm_out,
        "layer_scale": layer_scale,
        "mask_type": mask_type,
        "mask_random_seed": mask_random_seed,
        "sparse_attn_window": sparse_attn_window,
        "global_window": global_window,
        "sparsity": sparsity,
        "auto_sparsity": auto_sparsity,
        "batch_first": True,
    }
    kwargs_classic_encoder = dict(kwargs_common)
    kwargs_classic_encoder.update({
        "sparse": sparse_self_attn,
    })
    kwargs_cross_encoder = dict(kwargs_common)
    kwargs_cross_encoder.update({
        "sparse": sparse_cross_attn,
    })
    for idx in range(num_layers):
        if idx % 2 == self.classic_parity:
            self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
            self.layers_t.append(
                MyTransformerEncoderLayer(**kwargs_classic_encoder)
            )
        else:
            self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
            self.layers_t.append(
                CrossTransformerEncoderLayer(**kwargs_cross_encoder)
            )
And this is my forward method:
def forward(self, x, xt):
    # For x:
    B, C, Fr, T1 = x.shape
    pos_emb_2d = create_2d_sin_embedding(C, Fr, T1, x.device, self.max_period)  # (1, C, Fr, T1)
    pos_emb_2d = rearrange_b_cfrt1_to_b_t1fr_c(pos_emb_2d)
    x = rearrange_b_cfrt1_to_b_t1fr_c(x)
    x = self.norm_in(x)
    x = x + self.weight_pos_embed * pos_emb_2d
    # For xt:
    B, C, T2 = xt.shape
    xt = rearrange_b_c_t2_to_b_t2_c(xt)
    pos_emb = self._get_pos_embedding(int(T2), int(B), int(C), x.device)
    # Assume pos_emb shape is (T2, B, C); rearrange to (B, T2, C)
    pos_emb = pos_emb.permute(1, 0, 2)
    xt = self.norm_in_t(xt)
    xt = xt + self.weight_pos_embed * pos_emb
    # Processing layers:
    for index, layer in enumerate(self.layers):
        if index % 2 == self.classic_parity:
            x = layer(x)
            xt = self.layers_t[index](xt)
        else:
            old_x = x
            x = layer(x, xt)
            xt = self.layers_t[index](xt, old_x)
    x = rearrange_b_t1fr_c_to_b_c_fr_t1(x, T1)
    xt = rearrange_b_t2_c_to_b_c_t2(xt)
    return x, xt
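For completeness, this is roughly how I invoke the scripting (the class name and constructor arguments here are just placeholders for my real setup; torch.jit.script is the call that fails):

import torch

model = MyCrossTransformer(dim=512)  # placeholder class name and args
model.eval()
scripted = torch.jit.script(model)   # this is where the error is raised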
I know that the problem is in this for loop (I commented out various parts to narrow down what caused it, and this is it), in particular in the calls layer(x) and layer(x, xt):
for index, layer in enumerate(self.layers):
    if index % 2 == self.classic_parity:
        x = layer(x)
        xt = self.layers_t[index](xt)
    else:
        old_x = x
        x = layer(x, xt)
        xt = self.layers_t[index](xt, old_x)
What is wrong here, and could you help me find a way to solve it? I have no idea how to approach this.
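In case it helps, here is a stripped-down (untested) sketch of the pattern I think is causing trouble: a ModuleList mixing two module types whose forward methods take a different number of arguments, called inside a branch that is not known at compile time. LayerA and LayerB are just stand-ins for MyTransformerEncoderLayer and CrossTransformerEncoderLayer:

import torch
import torch.nn as nn


class LayerA(nn.Module):
    # stand-in for MyTransformerEncoderLayer: forward takes one tensor
    def forward(self, x):
        return x


class LayerB(nn.Module):
    # stand-in for CrossTransformerEncoderLayer: forward takes two tensors
    def forward(self, x, y):
        return x + y


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.classic_parity = 0
        self.layers = nn.ModuleList([LayerA(), LayerB(), LayerA(), LayerB()])

    def forward(self, x, xt):
        for index, layer in enumerate(self.layers):
            if index % 2 == self.classic_parity:
                x = layer(x)
            else:
                x = layer(x, xt)
        return x


scripted = torch.jit.script(Toy())  # I expect this to hit the same kind of error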
Another thing: if I comment out the lines x = layer(x) and x = layer(x, xt), I get a different error:
Expected integer literal for index but got a variable or non-integer. ModuleList/Sequential indexing is only supported with integer literals. For example, 'i = 4; self.layers[i](x)' will fail because i is not a literal. Enumeration is supported, e.g. 'for index, v in enumerate(self): out = v(inp)':
    if index % 2 == self.classic_parity:
        # x = layer(x)
        xt = self.layers_t[index](xt)
             ~~~~~~~~~~~~~~~~~~~~ <--- HERE
    else:
        old_x = x
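The only workaround I could think of for this second error is to stop indexing self.layers_t with a variable and instead walk both lists in lockstep, something like the sketch below. But I am not sure whether zip over two ModuleLists is even scriptable, and it would not address the first problem with the mixed layer(x) / layer(x, xt) calls:

# untested sketch: iterate both ModuleLists together instead of indexing layers_t
for index, (layer, layer_t) in enumerate(zip(self.layers, self.layers_t)):
    if index % 2 == self.classic_parity:
        x = layer(x)
        xt = layer_t(xt)
    else:
        old_x = x
        x = layer(x, xt)
        xt = layer_t(xt, old_x)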
If you need more info about the other classes, for example the transformer encoder layer or the cross transformer encoder layer, I'll send it! Your help will be much appreciated!