from nets import mobilenet
import torch.nn as nn
from nets.layers import convolution
class MobileCenterNet(nn.Module):
def __init__(self,device):
super(MobileCenterNet, self).__init__()
self.dimension = 128
self.device = device
self.backbone = mobilenet.mobilenetv2().to(self.device)
self.p3 = convolution(k=1, inp_dim=320, out_dim=self.dimension)
self.up_p3 = nn.Upsample(scale_factor=2,mode='nearest')
self.p2 = convolution(k=11, inp_dim=96, out_dim=self.dimension)
self.up_p2 = nn.Upsample(scale_factor=2,mode='nearest')
self.p1 = convolution(k=11, inp_dim=32, out_dim=self.dimension)
self.up_p1 = nn.Upsample(scale_factor=2,mode='nearest')
self.p0 = convolution(k=1, inp_dim=24, out_dim=self.dimension)
self.feature_layer = convolution(k=3,inp_dim=128,out_dim=self.dimension)
self.heatmap_layer = nn.Sequential(
nn.Conv2d(128, 64, 3, padding=1, stride=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 1, 1, padding=0, stride=1)
)
self.wh_layer = nn.Sequential(
nn.Conv2d(128, 64, 3, padding=1, stride=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 2, 1, padding=0, stride=1),
)
self.reg_layer = nn.Sequential(
nn.Conv2d(128, 64, 3, padding=1, stride=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 2, 1, padding=0, stride=1),
)
def forward(self, x):
# feature extractor
fm0, fm1, fm2, fm3 = self.backbone(x)
fm3 = self.p3(fm3)
up_fm3 = self.up_p3(fm3)
reduce_dim_fm2 = self.p2(fm2)
fm2 = 0.5 * up_fm3 + 0.5 * reduce_dim_fm2
up_fm2 = self.up_p2(fm2)
reduce_dim_fm1 = self.p1(fm1)
fm1 = 0.5 * up_fm2 + 0.5 * reduce_dim_fm1
up_fm1 = self.up_p1(fm1)
reduce_dim_fm0 = self.p0(fm0)
fm0 = 0.5 * up_fm1 + 0.5 * reduce_dim_fm0
features = self.feature_layer(fm0)
# detector
hm = self.heatmap_layer(features)
wh = self.wh_layer(features)
reg = self.reg_layer(features)
output = {'hm': hm, 'wh': wh, 'reg': reg}
return output
def get_centernet(device):
model = MobileCenterNet(device=device)
return model
I’m trying to train a model with centerNet and mobilenet but I think there are some problem with my code.
Could you find the problems,