Sorry for the late reply. The model definition is shown below:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.block1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),         # block1: (3, 622, 462)
            nn.Conv2d(3, 64, 3, 1, 0),                # (64, 620, 460)
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),         # (64, 622, 462)
            nn.Conv2d(64, 64, 3, 1, 0),               # (64, 620, 460)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2, return_indices=True),  # (64, 310, 230)
            # index 7
            nn.ReflectionPad2d((1, 1, 1, 1)),         # block2
            nn.Conv2d(64, 128, 3, 1, 0),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 128, 3, 1, 0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2, return_indices=True),
            # index 14
            nn.ReflectionPad2d((1, 1, 1, 1)),         # block3
            nn.Conv2d(128, 256, 3, 1, 0),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, 3, 1, 0),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, 3, 1, 0),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, 3, 1, 0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2, return_indices=True),
            # index 27
            nn.ReflectionPad2d((1, 1, 1, 1)),         # block4
            nn.Conv2d(256, 512, 3, 1, 0),
            nn.ReLU(inplace=True))
        # copy the pretrained VGG-19 conv weights/biases into the matching layers
        # (mapping: vgg19.features index -> self.block1 index)
        vgg_to_enc = {0: 1, 2: 4, 5: 8, 7: 11, 10: 15, 12: 18, 14: 21, 16: 24, 19: 28}
        with torch.no_grad():
            for v, e in vgg_to_enc.items():
                self.block1[e].weight.copy_(vgg_pretrained_features[v].weight)
                self.block1[e].bias.copy_(vgg_pretrained_features[v].bias)
    def forward(self, x):
        h = self.block1[0](x)
        h = self.block1[1](h)
        relu1_1 = self.block1[2](h)      # (b, 64, 620, 460)
        h = self.block1[3](relu1_1)
        h = self.block1[4](h)
        p1 = self.block1[5](h)
        h, idx1 = self.block1[6](p1)
        h = self.block1[7](h)
        h = self.block1[8](h)
        relu2_1 = self.block1[9](h)      # (b, 128, 310, 230)
        h = self.block1[10](relu2_1)
        h = self.block1[11](h)
        p2 = self.block1[12](h)
        h, idx2 = self.block1[13](p2)
        h = self.block1[14](h)
        h = self.block1[15](h)
        relu3_1 = self.block1[16](h)     # (b, 256, 155, 115)
        h = self.block1[17](relu3_1)
        h = self.block1[18](h)
        h = self.block1[19](h)           # ReLU after conv3_2
        h = self.block1[20](h)
        h = self.block1[21](h)
        h = self.block1[22](h)
        h = self.block1[23](h)
        h = self.block1[24](h)
        p3 = self.block1[25](h)
        h, idx3 = self.block1[26](p3)    # pool the activated p3
        h = self.block1[27](h)
        h = self.block1[28](h)
        out = self.block1[29](h)         # (b, 512, 77, 57)
        return out, relu1_1, relu2_1, relu3_1, idx1, idx2, idx3, p1, p2, p3
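A quick shape sanity check for the encoder (a minimal sketch, assuming the 620x460 input the comments above are written for):

import torch

enc = Encoder().eval()
with torch.no_grad():
    out, r1, r2, r3, *_ = enc(torch.randn(1, 3, 620, 460))
print(out.shape)                     # torch.Size([1, 512, 77, 57])
print(r1.shape, r2.shape, r3.shape)  # (1, 64, 620, 460), (1, 128, 310, 230), (1, 256, 155, 115)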
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.block2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 256, 3, 1, 0),
            nn.ReLU(inplace=True),
            nn.MaxUnpool2d(2, 2),
            # index 4
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, 3, 1, 0),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, 3, 1, 0),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, 3, 1, 0),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),  # index 13
            nn.Conv2d(512, 128, 3, 1, 0),      # 512 = 256 + 256 from the relu3_1 concat
            nn.ReLU(inplace=True),
            nn.MaxUnpool2d(2, 2),
            # index 17
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 128, 3, 1, 0),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),  # index 20
            nn.Conv2d(256, 64, 3, 1, 0),       # 256 = 128 + 128 from the relu2_1 concat
            nn.ReLU(inplace=True),
            nn.MaxUnpool2d(2, 2),
            # index 24
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 64, 3, 1, 0),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 3, 3, 1, 0),        # index 28; 128 = 64 + 64 from the relu1_1 concat
            nn.Tanh())
        # GroupNorm for the skip connections, registered here so the affine
        # parameters are trained and follow .cuda()/.to() with the module
        # (constructing GroupNorm inside forward() recreates it on the CPU
        # every call, which crashes on CUDA inputs and learns nothing)
        self.gn2 = nn.GroupNorm(32, 128)
        self.gn1 = nn.GroupNorm(32, 64)
    def forward(self, x, relu1_1, relu2_1, relu3_1, idx1, idx2, idx3, p1, p2, p3):
        h = self.block2[0](x)
        h = self.block2[1](h)
        h = self.block2[2](h)
        h = self.block2[3](h, idx3)
        h = self.block2[4](h)
        h = self.block2[5](h)
        h = self.block2[6](h)
        h = self.block2[7](h)
        h = self.block2[8](h)
        h = self.block2[9](h)
        h = self.block2[10](h)
        h = self.block2[11](h)
        h = self.block2[12](h)
        # relu3_1 = nn.GroupNorm(32, 256)(relu3_1)  # (disabled) normalization of this skip
        h = torch.cat((relu3_1, h), 1)
        h = self.block2[13](h)
        h = self.block2[14](h)
        h = self.block2[15](h)
        h = self.block2[16](h, idx2)
        h = self.block2[17](h)
        h = self.block2[18](h)
        h = self.block2[19](h)
        h = torch.cat((self.gn2(relu2_1), h), 1)
        h = self.block2[20](h)
        h = self.block2[21](h)
        h = self.block2[22](h)
        h = self.block2[23](h, idx1)
        h = self.block2[24](h)
        h = self.block2[25](h)
        h = self.block2[26](h)
        h = torch.cat((self.gn1(relu1_1), h), 1)
        h = self.block2[27](h)
        h = self.block2[28](h)
        out = self.block2[29](h)
        return out
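Note that p1, p2 and p3 are passed into the decoder but never used. If the training resolution is not divisible by 8 (e.g. the 620x460 input annotated in the encoder, where relu3_1 is 155x115), the unpooled maps come back one pixel short and the torch.cat calls fail; nn.MaxUnpool2d takes an output_size argument for exactly this case, which is presumably what p1/p2/p3 were meant for. A standalone sketch of the behavior (names here are illustrative only):

import torch
import torch.nn as nn

pool = nn.MaxPool2d(2, 2, return_indices=True)
unpool = nn.MaxUnpool2d(2, 2)
x = torch.randn(1, 1, 155, 115)              # odd spatial size, like the pre-pool maps
y, idx = pool(x)                             # -> (1, 1, 77, 57)
back = unpool(y, idx, output_size=x.size())  # -> (1, 1, 155, 115) instead of (1, 1, 154, 114)
print(y.shape, back.shape)

Inside the decoder that would become, e.g., self.block2[3](h, idx3, output_size=p3.size()), and likewise for idx2/p2 and idx1/p1.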
class dehazeNetwork(nn.Module):
    def __init__(self):
        super(dehazeNetwork, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        # self.lossnet = lossNetwork()
    def forward(self, x):
        out1, relu1_1, relu2_1, relu3_1, idx1, idx2, idx3, p1, p2, p3 = self.encoder(x)
        out_f = self.decoder(out1, relu1_1, relu2_1, relu3_1, idx1, idx2, idx3, p1, p2, p3)
        # out = self.lossnet(out_f)
        return out_f
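An end-to-end smoke test (a sketch, assuming height and width divisible by 8 so the unpool sizes line up without output_size):

net = dehazeNetwork()
with torch.no_grad():
    y = net(torch.randn(1, 3, 224, 224))
print(y.shape)  # torch.Size([1, 3, 224, 224]), values in (-1, 1) from the Tanh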
class dehazeNetworkWithLoss(nn.Module):
    def __init__(self, args):
        super(dehazeNetworkWithLoss, self).__init__()
        net = dehazeNetwork()
        lossnet = lossNetwork()
        # load pre-trained weights if a checkpoint path was given
        if args.load_model not in (None, 'none', 'None'):
            net.load_pred_model(args.load_model)
        print('=======dehazeNetwork\n', net)
        criterion = nn.MSELoss()
        self.net = net
        self.lossnet = lossnet
        self.criterion = criterion
    def forward(self, img, gt, per_w):
        out_r = self.net(img)  # (b, 3, 224, 224)
        # lossnet returns: relu1_2 [b, 64, 224, 224], relu2_2 [b, 128, 112, 112],
        # relu3_3 [b, 256, 56, 56], relu4_3 [b, 512, 28, 28]
        out_rrelu1_2, out_rrelu2_2, out_rrelu3_3, out_rrelu4_3 = self.lossnet(out_r)
        out_gtrelu1_2, out_gtrelu2_2, out_gtrelu3_3, out_gtrelu4_3 = self.lossnet(gt)
        l1 = self.criterion(out_r, gt)
        if per_w > 0:
            l2_1 = self.criterion(out_rrelu1_2, out_gtrelu1_2)
            l2_2 = self.criterion(out_rrelu2_2, out_gtrelu2_2)
            l2_3 = self.criterion(out_rrelu3_3, out_gtrelu3_3)
            l2_4 = self.criterion(out_rrelu4_3, out_gtrelu4_3)
            l2 = l2_1 + l2_2 + l2_3 + l2_4
        else:
            l2 = torch.zeros_like(l1)  # keep the return types consistent
        return l1, l2, out_r
wrap_net = dehazeNetworkWithLoss(args)
optimizer = df_optim.getOptimizer([{'params': wrap_net.net.decoder.parameters()}],
                                  args, args.optm, args.lr)
if torch.cuda.is_available():
    if len(gids) > 1:
        wrap_net = nn.DataParallel(wrap_net, device_ids=gids)
    wrap_net.cuda()
    cudnn.benchmark = True
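One thing to keep in mind with this wrapping: when a DataParallel module returns 0-dim losses, gather stacks one scalar per GPU into a 1-D tensor, which is why the training loop below reduces with .mean() before backward(). A toy illustration (hypothetical module, needs at least two visible GPUs to actually split):

import torch
import torch.nn as nn

class ToyLoss(nn.Module):
    def forward(self, x):
        return x.pow(2).mean()  # 0-dim on a single device

m = nn.DataParallel(ToyLoss(), device_ids=[0, 1]).cuda()
out = m(torch.randn(8, 3).cuda())
print(out.shape)  # torch.Size([2]): one scalar per replica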
Then here is the training part:
def train(epoch):
    wrap_net.train()
    train_loss = 0.0
    train_time = 0.0
    loading_time = 0.0
    end = time.time()
    for bi, (hinputs, ginputs) in enumerate(tr_loader):
        if torch.cuda.is_available():
            hinputs, ginputs = hinputs.cuda(), ginputs.cuda()
        loading_time += time.time() - end
        optimizer.zero_grad()
        loss_r, loss_p, target = wrap_net(hinputs, ginputs, args.per_w)
        loss = args.rec_w * loss_r + args.per_w * loss_p
        if len(gids) > 1:
            loss = loss.mean()  # DataParallel returns one scalar per GPU
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        train_time += time.time() - end
        end = time.time()
        if bi % args.print_freq == 1:
            print('training epoch: %d, minibatch: %d, loss: %f, total time/mb: %f ms, running time/mb: %f ms' % (
                epoch, bi, train_loss / (bi + 1),
                train_time / (bi + 1) * 1000.0, (train_time - loading_time) / (bi + 1) * 1000.0))
            print('ep%d mb%d loss details:' % (epoch, bi),
                  loss_r.view(-1).tolist(), loss_p.view(-1).tolist())
    return train_loss / len(tr_loader), train_time, loading_time
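A typical driver around it would look like this (a sketch; args.epochs and any checkpointing or validation are assumptions about the rest of the script):

for epoch in range(args.epochs):
    avg_loss, t_train, t_load = train(epoch)
    print('epoch %d: avg loss %.4f, train %.1fs (loading %.1fs)'
          % (epoch, avg_loss, t_train, t_load))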
Thanks.