Output mismatch with channels_last_3d?

Hi there,

I am trying to implement a convnextv2 model in 3D. I would like the model to support torch.channels_last_3d memory format and torch.compile().

I think I am really close, but I cannot get the outputs to match before and after the memory-format conversion. Any ideas?

from types import SimpleNamespace
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.layers import trunc_normal_, DropPath

###########
## Utils ##
###########


class LayerNorm3d(nn.Module):
    """Layer normalization over the channel axis with a learnable affine.

    Args:
        normalized_shape: Number of channels; also the length of the
            learnable ``weight`` and ``bias`` vectors.
        eps: Constant added to the variance before taking the square root.
        data_format: ``"channels_last"`` normalizes the trailing axis via
            ``F.layer_norm``; ``"channels_first"`` normalizes axis 1 with
            explicit mean/variance arithmetic and expects a 5-D input,
            because the affine parameters are expanded with three trailing
            singleton axes.

    Raises:
        NotImplementedError: If ``data_format`` is neither of the two
            supported strings.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        # Registration order (weight, then bias) is kept so parameter
        # iteration order is unchanged.
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        layout = self.data_format
        if layout == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if layout != "channels_first":
            # Only reachable if data_format was reassigned after construction.
            return None

        # Hand-written normalization along axis 1 (biased variance).
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)

        # Expand (C,) -> (C, 1, 1, 1) so the affine broadcasts over D, H, W.
        scale = self.weight[:, None, None, None]
        shift = self.bias[:, None, None, None]
        return scale * normalized + shift

class GRN3d(nn.Module):
    """ 3D GRN (Global Response Normalization) layer.

    ``gamma`` and ``beta`` have shape ``(1, 1, 1, 1, dim)``, so they
    broadcast along the last axis of a 5-D input; the L2 norm is reduced
    over axes 1, 2 and 3. Both parameters start at zero, which makes the
    layer an identity mapping at initialization.

    Args:
        dim (int): Size of the last (channel) axis of the input.
        eps (float): Constant added to the mean norm before dividing.
            Default: 1e-6
    """
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, 1, dim))
        self.eps = eps

    def forward(self, x):
        # torch.norm is deprecated; torch.linalg.vector_norm is the
        # documented replacement for a p-norm over several axes.
        Gx = torch.linalg.vector_norm(x, ord=2, dim=(1, 2, 3), keepdim=True)
        # Divide each channel's norm by the mean norm across channels.
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x * Nx) + self.beta + x


############
## Blocks ##
############

class Block(nn.Module):
    """ ConvNeXtV2 Block (3D).

    Depthwise 7x7x7 convolution, then (in channels-last layout) LayerNorm,
    a 4x pointwise expansion, GELU, GRN and a pointwise projection back to
    ``dim``, added to the block input as a residual.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """
    def __init__(self, dim, drop_path=0.):
        super().__init__()
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
        self.norm = LayerNorm3d(dim, data_format="channels_last")
        self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN3d(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        # Keep the block input for the residual connection (named so it
        # does not shadow the ``input`` builtin).
        shortcut = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 4, 1) # (N, C, D, H, W) -> (N, D, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 4, 1, 2, 3) # (N, D, H, W, C) -> (N, C, D, H, W)

        return shortcut + self.drop_path(x)

class ConvNeXtV2Encoder3d(nn.Module):
    """ 
    ConvNeXt V2 (3D) encoder with a classification head.
        
    Args:
        cfg (SimpleNamespace): Configuration object; ``cfg.backbone`` selects
            the model size ("atto", "femto", "pico", "nano", "tiny", "base",
            "large" or "huge").
        in_chans (int): Number of input image channels. Default: 3
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
        num_classes (int): Number of classes for classification head. Default: 700

    Raises:
        ValueError: If ``cfg.backbone`` is not one of the known backbone names.
    """
    def __init__(
        self, 
        cfg: SimpleNamespace,
        in_chans: int = 3,
        drop_path_rate: float = 0.0,
        head_init_scale: float = 1.0,
        num_classes: int = 700,
    ):
        super().__init__()
        self.cfg = cfg

        # Backbone init cfg: name -> (blocks per stage, channels per stage)
        bb = self.cfg.backbone
        backbone_cfg = {
            "atto":  ([2, 2, 6, 2], [40, 80, 160, 320]),
            "femto": ([2, 2, 6, 2], [48, 96, 192, 384]),
            "pico":  ([2, 2, 6, 2], [64, 128, 256, 512]),
            "nano":  ([2, 2, 8, 2], [80, 160, 320, 640]),
            "tiny":  ([3, 3, 9, 3], [96, 192, 384, 768]),
            "base":  ([3, 3, 27, 3], [128, 256, 512, 1024]),
            "large": ([3, 3, 27, 3], [192, 384, 768, 1536]),
            "huge":  ([3, 3, 27, 3], [352, 704, 1408, 2816]),
        }
        if bb not in backbone_cfg:
            raise ValueError(f"Backbone: {bb} not implemented.")
        depths, dims = backbone_cfg[bb]

        # Downsample layers
        self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm3d(dims[0])
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            # The last downsampling step leaves the depth axis untouched
            # and only halves H and W.
            window = (2, 2, 2) if i != 2 else (1, 2, 2)
            downsample_layer = nn.Sequential(
                LayerNorm3d(dims[i]),
                nn.Conv3d(dims[i], dims[i+1], kernel_size=window, stride=window),
            )
            self.downsample_layers.append(downsample_layer)

        # Stages
        self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
        # Stochastic depth rate grows linearly from 0 to drop_path_rate over all blocks.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        # Head
        # The head norm runs on the pooled, flattened (N, C) features, so it
        # must normalize the trailing axis. With the "channels_first" format
        # the affine weight is expanded to (C, 1, 1, 1) and would broadcast
        # against (N, C) into a (C, 1, N, C) tensor.
        self.norm = LayerNorm3d(dims[-1], data_format="channels_last")
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool3d((1, 1, 1)),
            nn.Flatten(start_dim=1, end_dim=-1),
        )
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(partial(self._init_weights, init_type="gelu-based"))
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m, init_type="relu-based"):
        """Initialize conv/linear weights of module ``m`` and zero its bias.

        ``"relu-based"`` uses Kaiming-normal (fan_out) initialization,
        ``"gelu-based"`` uses a truncated normal with std 0.02. Any other
        ``init_type`` or module type is left untouched.
        """
        if init_type == "relu-based":
            if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        elif init_type == "gelu-based":
            if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Linear)):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the four downsample/stage pairs, then pool, normalize and classify.

        Returns the classifier output of shape ``(N, num_classes)``.
        """
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = downsample(x)
            x = stage(x)

        x = self.pool(x)   # (N, C, D, H, W) -> (N, C)
        x = self.norm(x)
        x = self.head(x)
        return x

if __name__ == "__main__":

    def count_parameters(model):
        """Return the number of trainable parameters in ``model``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    cfg = SimpleNamespace()
    cfg.backbone = "atto"
    cfg.encoder_cfg = SimpleNamespace()

    m = ConvNeXtV2Encoder3d(
        cfg=cfg,
        **vars(cfg.encoder_cfg),
    )
    m.cuda().eval()

    # Param count
    n_params = count_parameters(m)
    print(f"Model: {type(m).__name__}")
    print("n_param: {:_}".format(n_params))

    # Normal
    x = torch.ones(8, 3, 16, 128, 128).cuda()
    with torch.no_grad():
        # Call the module rather than .forward() so hooks are honoured.
        z = m(x)
        print(z.shape)

    # Channels Last
    x = x.to(memory_format=torch.channels_last_3d)
    m = m.to(memory_format=torch.channels_last_3d)
    with torch.no_grad():
        z2 = m(x)
        print(z2.shape)

    # Diff
    # Use the absolute difference: a signed max would hide deviations
    # where z2 is larger than z.
    print("DIFF: {}".format((z - z2).abs().max()))

The issue was the numerical precision of the hand-written LayerNorm3d implementation. As the small rounding differences propagate through the network, the mismatch becomes more pronounced. Here is the new LayerNorm3d implementation that works with .to(memory_format=torch.channels_last_3d).

class LayerNorm3d(nn.Module):
    """Layer normalization over the channel axis, built on ``F.layer_norm``.

    Args:
        normalized_shape: Number of channels; also the length of the
            learnable ``weight`` and ``bias`` vectors.
        eps: Constant passed to ``F.layer_norm`` for numerical stability.
        data_format: ``"channels_last"`` normalizes the trailing axis
            directly; ``"channels_first"`` moves axis 1 to the end,
            normalizes, and moves it back.

    Raises:
        NotImplementedError: If ``data_format`` is neither of the two
            supported strings.
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(
                f"Unsupported data_format: {data_format!r}"
            )
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: movedim(1, -1) equals permute(0, 2, 3, 4, 1) for a
        # 5-D input, but does not hard-code the rank, so (N, C) or
        # (N, C, H, W) inputs are handled as well.
        normalized = F.layer_norm(
            x.movedim(1, -1),
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return normalized.movedim(-1, 1)