I see the topic,and has the some problem,
the code followed:
class desNet(nn.Module):
def __init__(self):
super(desNet, self).__init__()
# 输入层 CBRP
self.conv0 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn0 = nn.BatchNorm2d(num_features=64)
self.relu0 = nn.ReLU(inplace=True)
self.pool0 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
block_transition = {'b1': [64, 96, 128, 160, 192, 224],
't1': [256, 128],
'b2': [128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448, 480],
't2': [512, 256],
'b3': [256, 288, 320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992],
't3': [1024, 512],
'b4': [512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992]
}
self.desblock_net = nn.Sequential()
for i, name in enumerate(block_transition):
if list(name)[0] == 'b':
nums = block_transition.get(name)
for ik, n in enumerate(nums):
self.desblock_net.add_module("{0}_{1}_BN".format(name, ik), nn.BatchNorm2d(num_features=n))
self.desblock_net.add_module("{0}_{1}_Conv".format(name, ik), nn.Conv2d(in_channels=n, out_channels=128, kernel_size=1, stride=1, bias=False))
self.desblock_net.add_module("{0}_{1}_Relu".format(name, ik), nn.ReLU(inplace=True))
self.desblock_net.add_module("{0}_{1}_bn".format(name, ik), nn.BatchNorm2d(num_features=128))
self.desblock_net.add_module("{0}_{1}_conv".format(name, ik), nn.Conv2d(in_channels=128, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False))
self.desblock_net.add_module("{0}_{1}_relu".format(name, ik), nn.ReLU(inplace=True))
elif list(name)[0] == 't':
nums = block_transition[name]
self.desblock_net.add_module("{0}_{1}_bn".format(name, i), nn.BatchNorm2d(num_features=nums[0]))
self.desblock_net.add_module("{0}_{1}_relu".format(name, i), nn.ReLU(inplace=True))
self.desblock_net.add_module("{0}_{1}_conv".format(name, i), nn.Conv2d(in_channels=nums[0], out_channels=nums[1], kernel_size=1, stride=1, bias=False))
self.desblock_net.add_module("{0}_{1}_avgpool".format(name, i), nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
分类层
self.bn = nn.BatchNorm2d(num_features=1024)
self.fc = nn.Linear(in_features=1024, out_features=1000)
def forward(self, input):
out = self.conv0(input)
out = self.bn0(out)
out = self.relu0(out)
out = self.pool0(out)
print(out.size())
out = self.desblock_net(input)
out = self.bn(out)
out = self.fc(out)
return out
if name == ‘main’:
mydesnet = desNet()
print(mydesnet)
x = torch.rand(10, 3, 224, 224)
out = mydesnet(x)
print(out.size())