Undefined value output with torch.cat and torch.jit.save

Hi all,
I was trying to save my model in a Kaggle notebook using torch.jit.save, but I got an "undefined value output" error.
This is my model:

class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
          https://arxiv.org/pdf/2103.14030
    Args:
        img_size (int | tuple(int)): Input image size.
        patch_size (int | tuple(int)): Patch size.
        in_chans (int): Number of input channels.
        num_classes (int): Number of classes for classification head.
        embed_dim (int): Embedding dimension.
        depths (tuple(int)): Depth of Swin Transformer layers.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate.
        drop_path_rate (float): Stochastic depth rate.
        norm_layer (nn.Module): Normalization layer.
        ape (bool): If True, add absolute position embedding to the patch embedding.
        patch_norm (bool): If True, add normalization after patch embedding.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_dense_prediction=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        
        self.linear = nn.Linear(1000,64)

        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # Region prediction head
        self.use_dense_prediction = use_dense_prediction
        if self.use_dense_prediction: self.head_dense = None


        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        # todo: to be implemented
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x_region = self.norm(x)  # B L C
        x = self.avgpool(x_region.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x


    def forward_feature_maps(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x_grid = self.norm(x)  # B L C
        x = self.avgpool(x_grid.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)

        return x, x_grid


    def forward(self, x):
        # convert to list
        if not isinstance(x, list):
            x = [x]
        # Perform the forward pass separately for each input resolution.
        # Inputs that share a resolution are grouped together and sent through
        # a single forward pass, so the number of forward passes equals the
        # number of distinct resolutions used. We then concatenate all the
        # output features.

        # When the region-level prediction task is used, the network outputs four variables:
        # self.head(output_cls):       view-level prob vector
        # self.head_dense(output_fea): region-level prob vector
        # output_fea:                  region-level feature map (grid features)
        # npatch:                      number of patches per view
        
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)
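        # e.g. for crop widths [224, 224, 96, 96, 96], unique_consecutive
        # counts are [2, 3], so idx_crops = [2, 5]: the end index of each
        # group of consecutive same-resolution crops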

        start_idx = 0
        for end_idx in idx_crops:
            _out = self.forward_features(torch.cat(x[start_idx: end_idx]))
            if start_idx == 0:
                output = _out
            else:
                output = torch.cat((output, _out), dim=0)
            start_idx = end_idx
            # Run the head forward on the concatenated features.
        return self.head(output) #self.linear

student = SwinTransformer(patch_size=4, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                          window_size=7, mlp_ratio=4., qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0,
                          drop_path_rate=0.2, ape=False, patch_norm=True)

I save the model using torch.jit:

saved_model = torch.jit.script(student)
saved_model.save('saved_model.pt')

and this is the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_17/3296653862.py in <module>
      1 from zipfile import ZipFile
      2 
----> 3 saved_model = torch.jit.script(student)
      4 saved_model.save('saved_model.pt')
      5 with ZipFile('submission.zip','w') as zip:

/opt/conda/lib/python3.7/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1264         obj = call_prepare_scriptable_func(obj)
   1265         return torch.jit._recursive.create_script_module(
-> 1266             obj, torch.jit._recursive.infer_methods_to_compile
   1267         )
   1268 

/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
    452     if not is_tracing:
    453         AttributeTypeIsSupportedChecker().check(nn_module)
--> 454     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    455 
    456 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    518     # Compile methods if necessary
    519     if concrete_type not in concrete_type_store.methods_compiled:
--> 520         create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    521         # Create hooks after methods to ensure no name collisions between hooks and methods.
    522         # If done before, hooks can overshadow methods that aren't exported.

/opt/conda/lib/python3.7/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    369     property_rcbs = [p.resolution_callback for p in property_stubs]
    370 
--> 371     concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    372 
    373 def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs):

RuntimeError: 
undefined value output:
  File "/tmp/ipykernel_17/2594581719.py", line 183
                output = _out
            else:
                output = torch.cat((output, _out), dim=0)
                                    ~~~~~~ <--- HERE
            start_idx = end_idx
            # Run the head forward on the concatenated features.

Hope you guys can help me. Thanks!!!

Your forward pass does not guarantee that output is defined when the else branch runs, so scripting fails:

    def forward(self, x):
        # convert to list
        if not isinstance(x, list):
            x = [x]
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)

        start_idx = 0
        for end_idx in idx_crops:
            _out = self.forward_features(torch.cat(x[start_idx: end_idx]))
            if start_idx == 0:
                output = _out
            else:
                output = torch.cat((output, _out), dim=0)
            start_idx = end_idx

While output is assigned _out when start_idx is 0, the compiler still treats it as undefined in the else branch, since it is never initialized before the loop.
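One way to make this scriptable is to define the accumulator before the loop and concatenate once at the end, so the variable exists on every path. Below is a minimal sketch of forward, my assumption rather than a tested drop-in for your full model; note the input is annotated as List[torch.Tensor], and the tensor loop bounds are converted with int() because TorchScript only slices lists with integers:

    # needs: from typing import List (at module level)
    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)

        start_idx = 0
        outputs: List[torch.Tensor] = []       # accumulator exists before the loop
        for i in range(idx_crops.numel()):
            end_idx = int(idx_crops[i])        # TorchScript needs int list indices
            outputs.append(self.forward_features(torch.cat(x[start_idx:end_idx])))
            start_idx = end_idx
        output = torch.cat(outputs, dim=0)     # defined on every code path
        return self.head(output)

The isinstance(x, list) wrapper is dropped here: without a Union annotation, script mode types the argument from its annotation alone, so the simplest scriptable form is a list-only signature with single tensors wrapped at the call site.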

Thank you for your help! But output will always be initialized in the if branch first, and only then can the else branch run; I can still train the model perfectly well. Why does torch.jit.script require output to be defined for the else branch?
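For what it's worth, the reason is that torch.jit.script compiles forward ahead of time from its source, without running it: the compiler must prove that every variable is assigned on every path that can reach a use, and it does not track the runtime fact that start_idx == 0 on the first iteration. Eager mode only checks names when a line actually executes, which is why training works. A minimal standalone sketch, illustrative only, reproduces the same error:

    import torch

    def f(xs: torch.Tensor) -> torch.Tensor:
        for i in range(xs.shape[0]):
            if i == 0:
                acc = xs[0]
            else:
                acc = acc + xs[i]  # compiler cannot prove acc is defined here
        return acc

    f(torch.arange(4.))      # works in eager mode
    torch.jit.script(f)      # RuntimeError: undefined value acc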