Hi there,

I am trying to implement a ConvNeXt V2 model in 3D. I would like the model to support the torch.channels_last_3d memory format and torch.compile().

I think I am really close, but I cannot get the model output to match before and after the memory-format conversion. Any ideas?
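For context, my understanding is that channels_last_3d only changes the strides (channels become innermost in memory) while the logical shape stays (N, C, D, H, W), so I would expect the converted model to produce the same output up to floating-point noise. A minimal standalone sketch of what I mean (the tensor here is just for illustration):

import torch

x = torch.randn(2, 3, 8, 16, 16)
y = x.to(memory_format=torch.channels_last_3d)
print(x.shape == y.shape)                                     # True, logical shape unchanged
print(y.is_contiguous(memory_format=torch.channels_last_3d))  # True
print(x.stride(), y.stride())                                 # only the strides differ

Here is the full model and the repro script: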
from types import SimpleNamespace
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import trunc_normal_, DropPath
###########
## Utils ##
###########
class LayerNorm3d(nn.Module):
    """ LayerNorm that supports two data formats:
    channels_last (N, D, H, W, C) or channels_first (N, C, D, H, W).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
            return x
class GRN3d(nn.Module):
    """ 3D GRN (Global Response Normalization) layer.
    Expects channels_last input of shape (N, D, H, W, C).
    """
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1, 2, 3), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x
############
## Blocks ##
############
class Block(nn.Module):
    """ ConvNeXtV2 Block.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """
    def __init__(self, dim, drop_path=0.):
        super().__init__()
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm3d(dim, data_format="channels_last")
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN3d(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 4, 1)  # (N, C, D, H, W) -> (N, D, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 4, 1, 2, 3)  # (N, D, H, W, C) -> (N, C, D, H, W)
        x = input + self.drop_path(x)
        return x
class ConvNeXtV2Encoder3d(nn.Module):
    """ ConvNeXt V2 encoder in 3D.

    Args:
        cfg (SimpleNamespace): Config with a `backbone` attribute selecting the model size.
        in_chans (int): Number of input image channels. Default: 3
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
        num_classes (int): Number of classes for classification head. Default: 700
    """
    def __init__(
        self,
        cfg: SimpleNamespace,
        in_chans: int = 3,
        drop_path_rate: float = 0.0,
        head_init_scale: float = 1.0,
        num_classes: int = 700,
    ):
        super().__init__()
        self.cfg = cfg

        # Backbone init cfg
        bb = self.cfg.backbone
        backbone_cfg = {
            "atto":  ([2, 2, 6, 2],  [40, 80, 160, 320]),
            "femto": ([2, 2, 6, 2],  [48, 96, 192, 384]),
            "pico":  ([2, 2, 6, 2],  [64, 128, 256, 512]),
            "nano":  ([2, 2, 8, 2],  [80, 160, 320, 640]),
            "tiny":  ([3, 3, 9, 3],  [96, 192, 384, 768]),
            "base":  ([3, 3, 27, 3], [128, 256, 512, 1024]),
            "large": ([3, 3, 27, 3], [192, 384, 768, 1536]),
            "huge":  ([3, 3, 27, 3], [352, 704, 1408, 2816]),
        }
        if bb in backbone_cfg:
            depths, dims = backbone_cfg[bb]
        else:
            raise ValueError(f"Backbone: {bb} not implemented.")

        # Downsample layers: stem and 3 intermediate downsampling conv layers
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm3d(dims[0]),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm3d(dims[i]),
                nn.Conv3d(
                    dims[i],
                    dims[i + 1],
                    kernel_size=(2, 2, 2) if i != 2 else (1, 2, 2),
                    stride=(2, 2, 2) if i != 2 else (1, 2, 2),
                ),
            )
            self.downsample_layers.append(downsample_layer)

        # Stages: 4 feature resolution stages, each consisting of multiple residual blocks
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        # Head
        self.norm = LayerNorm3d(dims[-1], data_format="channels_last")  # applied after global pooling, on (N, C)
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool3d((1, 1, 1)),
            nn.Flatten(start_dim=1, end_dim=-1),
        )
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(partial(self._init_weights, init_type="gelu-based"))
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)
    def _init_weights(self, m, init_type="relu-based"):
        if init_type == "relu-based":
            if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        elif init_type == "gelu-based":
            if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Linear)):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        x = self.pool(x)
        x = self.norm(x)
        x = self.head(x)
        return x
if __name__ == "__main__":

    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    cfg = SimpleNamespace()
    cfg.backbone = "atto"
    cfg.encoder_cfg = SimpleNamespace()

    m = ConvNeXtV2Encoder3d(
        cfg=cfg,
        **vars(cfg.encoder_cfg),
    )
    m.cuda().eval()

    # Param count
    n_params = count_parameters(m)
    print(f"Model: {type(m).__name__}")
    print("n_param: {:_}".format(n_params))

    # Normal (contiguous) forward pass
    x = torch.ones(8, 3, 16, 128, 128).cuda()
    with torch.no_grad():
        z = m.forward(x)
    print(z.shape)

    # Channels-last forward pass
    x = x.to(memory_format=torch.channels_last_3d)
    m = m.to(memory_format=torch.channels_last_3d)
    with torch.no_grad():
        z2 = m.forward(x)
    print(z2.shape)

    # Diff (max absolute difference between the two outputs)
    print("DIFF: {}".format((z - z2).abs().max()))