Unable to reproduce SegFormer code

Hello, I am working on a demo semantic-segmentation project using SegFormer. I am re-implementing the code from scratch, based on the original codebase, with some parameters tweaked.

I am not able to train the model properly. When I logged the model to TensorBoard, I noticed that not all transformer blocks were appearing, and there were also many breaks in the computation graph.

class MixVisionTransformer(nn.Module):
    """Hierarchical Mix Transformer (MiT) encoder used by SegFormer.

    The encoder has four stages. Each stage does three things:
      * applies an ``OverlappedPatchEmbedding``;
      * runs a stack of ``Transformer`` blocks;
      * applies a norm layer and reshapes the tokens back into a
        ``(B, C, H, W)`` feature map, which is both collected as an output
        and fed to the next stage.

    Args:
        in_channel (int): number of input image channels. Defaults to 3.
        image_size (int): stored as ``self.image_size``; not used by the layers
            built here. Defaults to 256.
        patch_size (int): accepted for API compatibility only; the per-stage
            patch sizes are fixed to (7, 3, 3, 3). Defaults to 4.
        num_classes (int): stored as ``self.projection_dim``; not used by the
            layers built here. Defaults to 512.
        embed_dims (sequence of int): embedding dimension of each stage.
            Defaults to (64, 128, 256, 512).
        num_heads (sequence of int): attention heads per stage.
            Defaults to (1, 2, 4, 8).
        mlp_ratio (float): MLP hidden dim / embedding dim, passed to every
            block. Defaults to 4.0.
        qkv_bias (bool): passed to every ``Transformer`` block. Defaults to False.
        qk_scale (float): passed to every ``Transformer`` block. Defaults to 0.
            NOTE(review): the reference SegFormer uses ``None`` here. ``0.`` is
            only equivalent if ``Transformer`` falls back to
            ``head_dim ** -0.5`` for falsy values. Confirm this.
        drop_rate (float): dropout rate passed to every block. Defaults to 0.
        attn_drop_rate (float): attention dropout rate passed to every block.
            Defaults to 0.
        drop_path_rate (float): maximum stochastic-depth rate. Per-block rates
            increase linearly from 0 to this value across all blocks.
            Defaults to 0.1.
        norm_layer: norm-layer factory, called with a stage's embedding dim.
            Used after each stage and passed into every block.
            Defaults to ``nn.LayerNorm``.
        depths (sequence of int): number of ``Transformer`` blocks per stage.
            Defaults to (3, 4, 6, 3).
        sr_ratios (sequence of int): per-stage value passed to the blocks as
            ``sr_ratio``. Defaults to (8, 4, 2, 1).

    Returns (from ``forward``):
        list[torch.Tensor]: one feature map per stage.
    """

    def __init__(self, in_channel: int = 3, image_size: int = 256, patch_size: int = 4, num_classes: int = 512,
                 embed_dims: "list[int] | tuple[int, ...]" = (64, 128, 256, 512),
                 num_heads: "list[int] | tuple[int, ...]" = (1, 2, 4, 8),
                 mlp_ratio: float = 4., qkv_bias: bool = False, qk_scale: float = 0.,
                 drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0.1,
                 norm_layer: "type[nn.Module]" = nn.LayerNorm,
                 depths: "list[int] | tuple[int, ...]" = (3, 4, 6, 3),
                 sr_ratios: "list[int] | tuple[int, ...]" = (8, 4, 2, 1)) -> None:
        super(MixVisionTransformer, self).__init__()

        self.projection_dim = num_classes
        self.image_size = image_size
        # Copy into fresh lists so instances never share or mutate the defaults.
        self.embed_dims = list(embed_dims)
        self.num_heads = list(num_heads)

        # Stage 1 embeds the raw image with a 7x7 / stride-4 kernel; later
        # stages embed the previous stage's map with a 3x3 / stride-2 kernel.
        self.patch_embed1 = OverlappedPatchEmbedding(patch_size=7, in_channels=in_channel, stride=4, embed_dim=self.embed_dims[0])
        self.patch_embed2 = OverlappedPatchEmbedding(patch_size=3, in_channels=self.embed_dims[0], stride=2, embed_dim=self.embed_dims[1])
        self.patch_embed3 = OverlappedPatchEmbedding(patch_size=3, in_channels=self.embed_dims[1], stride=2, embed_dim=self.embed_dims[2])
        self.patch_embed4 = OverlappedPatchEmbedding(patch_size=3, in_channels=self.embed_dims[2], stride=2, embed_dim=self.embed_dims[3])

        # Stochastic-depth rates increase linearly over all blocks of all stages.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        def make_stage(stage: int, first_block: int) -> nn.ModuleList:
            """Build the ``Transformer`` blocks of ``stage`` (0-based index)."""
            return nn.ModuleList([
                Transformer(in_channels=self.embed_dims[stage], num_heads=self.num_heads[stage], mlp_ratio=mlp_ratio,
                            qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
                            drop_path=dpr[first_block + i], norm_layer=norm_layer, sr_ratio=sr_ratios[stage],
                            view_attention=True)
                for i in range(depths[stage])
            ])

        # Attribute names (transformerblockN / normN) are kept so existing
        # checkpoints and state_dict keys still match.
        cur = 0
        self.transformerblock1 = make_stage(0, cur)
        self.norm1 = norm_layer(self.embed_dims[0])

        cur += depths[0]
        self.transformerblock2 = make_stage(1, cur)
        self.norm2 = norm_layer(self.embed_dims[1])

        cur += depths[1]
        self.transformerblock3 = make_stage(2, cur)
        self.norm3 = norm_layer(self.embed_dims[2])

        cur += depths[2]
        self.transformerblock4 = make_stage(3, cur)
        self.norm4 = norm_layer(self.embed_dims[3])

        # _init_weights was previously defined but never applied.
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialise one submodule in place. Intended for ``self.apply``.

        * Linear: truncated-normal weights (std 0.02) and zero bias.
        * LayerNorm: weight 1 and bias 0.
        * Conv2d: normal weights with std ``sqrt(2 / fan_out)``, where fan_out
          is divided by ``groups``, and zero bias.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Affine parameters are absent when elementwise_affine=False.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> list:
        """Forward pass of the MixVisionTransformer.

        Args:
            x (torch.Tensor): input batch; assumed shape (B, C, H, W).

        Returns:
            list[torch.Tensor]: four feature maps, one per stage. Stage i has
            shape (B, embed_dims[i], H_i, W_i), where H_i and W_i are whatever
            that stage's patch embedding reports.
        """
        B = x.shape[0]
        outs = []

        stages = (
            (self.patch_embed1, self.transformerblock1, self.norm1),
            (self.patch_embed2, self.transformerblock2, self.norm2),
            (self.patch_embed3, self.transformerblock3, self.norm3),
            (self.patch_embed4, self.transformerblock4, self.norm4),
        )
        for patch_embed, blocks, norm in stages:
            x, H, W = patch_embed(x)
            for blk in blocks:
                x = blk(x, H, W)
            # Tokens -> channels-first map. Assumes patch_embed returns tokens
            # laid out as (B, H*W, C). TODO confirm.
            x = norm(x).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            # Previously missing: without this, forward always returned [].
            outs.append(x)

        return outs

This is the code for the transformer encoder.

> I am not able to train the model properly

By this, do you mean that the loss is not decreasing, or something else?

> when I logged the model to tensorboard, I noticed not all transformer blocks were appearing and also there were many breaks in the computation graph

Are you able to run TensorBoard using the original codebase?

Hi @soulitzer, I figured out my mistake with the TensorBoard issue: the transformer blocks were not named properly.

As for the training issue, it still persists: the loss decreases only very slightly with each epoch. I am using an IoU loss combined with a weighted cross-entropy loss. However, the predicted mask, which I log every few epochs, is always the same regardless of the input image. If I don't ignore any channel during the loss calculation, the mask is entirely class 0 (background); if I ignore class 0, it is entirely class 1.