Pyramid Vision Transformer V2 (PVTv2) is built from four stages, and the author's implementation registers and runs those stages with a `for` loop over `num_stages`. I would like to remove this loop and write the four stages out explicitly, so that I can analyze each stage more easily.
Original implementation: please check class PyramidVisionTransformerV2.
My implementation splits the stages apart using explicit indexing:
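For context, the loop I removed looks roughly like this in the official pvt_v2.py (a paraphrase from the repo, not the verbatim code; the real loop also registers a per-stage norm layer the same way):

    # The original PVTv2 builds each stage in a loop and registers the
    # modules by name with setattr; forward_features later fetches them
    # with getattr(self, f"patch_embed{i + 1}") etc.
    for i in range(num_stages):
        patch_embed = OverlapPatchEmbed(
            img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
            patch_size=7 if i == 0 else 3,
            stride=4 if i == 0 else 2,
            in_chans=in_chans if i == 0 else embed_dims[i - 1],
            embed_dim=embed_dims[i])
        block = nn.ModuleList([Block(dim=embed_dims[i], num_heads=num_heads[i],
                                     mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
                                     qk_scale=qk_scale, drop=drop_rate,
                                     attn_drop=attn_drop_rate, drop_path=dpr[cur + j],
                                     norm_layer=norm_layer, sr_ratio=sr_ratios[i],
                                     linear=linear)
                               for j in range(depths[i])])
        setattr(self, f"patch_embed{i + 1}", patch_embed)
        setattr(self, f"block{i + 1}", block)

Unrolling this into patch_embed1..patch_embed4 and block1..block4 gives the version below.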
import math
from functools import partial

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_

# Block, OverlapPatchEmbed and the loss heads (Softmax, CosFace, ArcFace,
# SFaceLoss, MagFaceHeader) are defined elsewhere in the same pvt_v2.py.

class PyramidVisionTransformerV2(nn.Module):
    # note: patch_size is accepted here but the per-stage embedders below
    # hardcode their own patch sizes (7 for stage 1, 3 for stages 2-4)
    def __init__(self, *, img_size=112, patch_size=8, loss_type, GPU_ID, in_chans=3, num_classes=1000,
                 embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, linear=True):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages
        # patch_embed: stage 1 downsamples 4x, stages 2-4 each downsample 2x
        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])

        # learned positional embeddings, one per stage; stage 4 reserves one
        # extra position for the cls token prepended in forward_features
        self.pos_embed1 = nn.Parameter(torch.zeros(1, self.patch_embed1.num_patches, embed_dims[0]))
        self.pos_drop1 = nn.Dropout(p=drop_rate)
        self.pos_embed2 = nn.Parameter(torch.zeros(1, self.patch_embed2.num_patches, embed_dims[1]))
        self.pos_drop2 = nn.Dropout(p=drop_rate)
        self.pos_embed3 = nn.Parameter(torch.zeros(1, self.patch_embed3.num_patches, embed_dims[2]))
        self.pos_drop3 = nn.Dropout(p=drop_rate)
        self.pos_embed4 = nn.Parameter(torch.zeros(1, self.patch_embed4.num_patches + 1, embed_dims[3]))
        self.pos_drop4 = nn.Dropout(p=drop_rate)
        # stochastic depth decay rule: drop_path grows linearly from 0 to
        # drop_path_rate across all sum(depths) blocks, e.g. with
        # depths=[3, 4, 6, 3] and drop_path_rate=0.1 the 16 blocks get
        # 0.0000, 0.0067, ..., 0.1000
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        self.block1 = nn.ModuleList([Block(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0], linear=linear)
            for i in range(depths[0])])
        cur += depths[0]
        self.block2 = nn.ModuleList([Block(
            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[1], linear=linear)
            for i in range(depths[1])])
        cur += depths[1]
        self.block3 = nn.ModuleList([Block(
            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[2], linear=linear)
            for i in range(depths[2])])
        cur += depths[2]
        self.block4 = nn.ModuleList([Block(
            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[3], linear=linear)  # linear= was dropped when unrolling; restored to match stages 1-3
            for i in range(depths[3])])
        self.norm = norm_layer(embed_dims[3])
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))
        # the "head" is just a LayerNorm producing the face embedding
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dims[3]),
        )

        # margin-based classification heads used during training
        self.loss_type = loss_type
        self.GPU_ID = GPU_ID
        if self.loss_type == 'None':
            print("no loss for vit_face")
        else:
            if self.loss_type == 'Softmax':
                self.loss = Softmax(in_features=embed_dims[3], out_features=num_classes, device_id=self.GPU_ID)
            elif self.loss_type == 'CosFace':
                self.loss = CosFace(in_features=embed_dims[3], out_features=num_classes, device_id=self.GPU_ID)
            elif self.loss_type == 'ArcFace':
                self.loss = ArcFace(in_features=embed_dims[3], out_features=num_classes, device_id=self.GPU_ID)
            elif self.loss_type == 'SFace':
                self.loss = SFaceLoss(in_features=embed_dims[3], out_features=num_classes, device_id=self.GPU_ID)
            elif self.loss_type == 'MagFace':
                self.loss = MagFaceHeader(in_features=embed_dims[3], out_features=num_classes)

        trunc_normal_(self.pos_embed1, std=.02)
        trunc_normal_(self.pos_embed2, std=.02)
        trunc_normal_(self.pos_embed3, std=.02)
        trunc_normal_(self.pos_embed4, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
    def reset_drop_path(self, drop_path_rate):
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]
        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]
        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]
        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    # NOTE: get_classifier / reset_classifier are leftovers from the
    # classification model; this variant never defines self.head or
    # self.embed_dim, so calling them raises AttributeError.
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
    def forward_features(self, x):
        B = x.shape[0]

        # stage 1
        # OverlapPatchEmbed.forward returns three values (x, H, W); the
        # original two-value unpack `x, (H, W) = ...` raised the ValueError
        # shown in the traceback below
        x, H, W = self.patch_embed1(x)
        x = x + self.pos_embed1
        x = self.pos_drop1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()  # tokens -> feature map for the next stage

        # stage 2
        x, H, W = self.patch_embed2(x)
        x = x + self.pos_embed2
        x = self.pos_drop2(x)
        for blk in self.block2:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        # stage 3
        x, H, W = self.patch_embed3(x)
        x = x + self.pos_embed3
        x = self.pos_drop3(x)
        for blk in self.block3:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        # stage 4: prepend a cls token and return its final state
        x, H, W = self.patch_embed4(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed4
        x = self.pos_drop4(x)
        for blk in self.block4:
            x = blk(x, H, W)
        x = self.norm(x)
        return x[:, 0]
    def forward(self, x, label=None, mask=None):
        x = self.forward_features(x)
        emb = self.mlp_head(x)
        if label is not None:
            x = self.loss(emb, label)
            return x, emb
        else:
            return emb
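With the unpacking fixed as above, a quick smoke test looks like this (a sketch, assuming the OverlapPatchEmbed, Block and loss-head classes from the same pvt_v2.py, and that each patch embedder's num_patches matches the token count it actually emits, so the pos_embed additions broadcast):

    # loss_type='None' skips the classification head, so no labels are needed
    model = PyramidVisionTransformerV2(img_size=112, patch_size=8,
                                       loss_type='None', GPU_ID=None)
    emb = model(torch.randn(2, 3, 112, 112))
    print(emb.shape)  # expected: torch.Size([2, 512]), i.e. embed_dims[3]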
Traceback (raised by the original two-value unpack x, (H, W) = self.patch_embed1(x) in forward_features):
Traceback (most recent call last):
File "/media/khawar/HDD_Khawar/facerectransformer-main/code/train.py", line 231, in <module>
print_per_layer_stat=True, verbose=True)
File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/ptflops/flops_counter.py", line 43, in get_model_complexity_info
_ = flops_model(batch)
File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/media/khawar/HDD_Khawar/facerectransformer-main/code/vit_pytorch/pvt_v2.py", line 424, in forward
x = self.forward_features(x)
File "/media/khawar/HDD_Khawar/facerectransformer-main/code/vit_pytorch/pvt_v2.py", line 387, in forward_features
x, (H, W) = self.patch_embed1(x)
ValueError: too many values to unpack (expected 2)
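The error is a plain unpacking mismatch: in the reference pvt_v2.py, OverlapPatchEmbed.forward ends with return x, H, W, i.e. three values, so a two-target unpack fails:

    x, (H, W) = self.patch_embed1(x)  # ValueError: too many values to unpack (expected 2)
    x, H, W = self.patch_embed1(x)    # works: unpacks all three return values

The same change is needed at all four patch_embed call sites, which is what the forward_features above already applies.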