Issue summary
I am working on a CNN encoder-decoder model for image prediction.
The input is a 2-channel image and the output is a 1-channel image; the model is trained with MSE loss and the SGD optimizer.
The training loss appears to decrease, but the predicted images have many problems.
Does anyone have a clue about what causes this issue?
I have adjusted the network design and the hyperparameters again and again,
but the predictions always show the same kind of problem.
Thank you in advance.
The network details
The input has 2 channels of float values at 64 by 64 pixels, and the label has 1 channel of float values at 64 by 64 pixels.
Each input channel is encoded independently by its own CNN, and the encoded features are concatenated before decoding.
The decoder uses convolution layers with ReLU activations and returns the regression image.
The network is as follows:
class EncoderDecoderRegression(nn.Module):
    """Two-branch CNN encoder with a single upsampling decoder for image regression.

    Each of the two inputs goes through its own encoder made of three
    Conv -> MaxPool -> BatchNorm -> ReLU stages (1 -> 8 -> 16 -> 32 channels),
    where every max-pool halves the height and width. The two 32-channel
    feature maps are concatenated along the channel axis and decoded by three
    Upsample -> Conv -> ReLU stages (64 -> 32 -> 16 -> 1 channels), where every
    upsample doubles the height and width.
    """

    def __init__(self):
        super().__init__()
        # Encoder networks: one independent branch per input, same architecture.
        self.conv_x1 = self._make_encoder()
        self.conv_x2 = self._make_encoder()
        # Decoder network: consumes the channel-wise concatenation of both
        # encoder outputs (32 + 32 = 64 channels).
        # NOTE(review): the trailing ReLU clamps the prediction to >= 0, so
        # negative label values can never be reproduced -- confirm that the
        # labels are non-negative or drop the last activation.
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 16, 3, 1, 1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(16, 1, 3, 1, 1),
            nn.ReLU(),
        )

    @staticmethod
    def _make_encoder():
        """Build one encoder branch: 1 -> 8 -> 16 -> 32 channels, three 2x max-pools."""
        return nn.Sequential(
            nn.Conv2d(1, 8, 3, 1, 1),
            nn.MaxPool2d(2, 2, 0),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.Conv2d(8, 16, 3, 1, 1),
            nn.MaxPool2d(2, 2, 0),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.MaxPool2d(2, 2, 0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

    def forward(self, x_1, x_2):
        """Encode both inputs separately, concatenate the features and decode.

        Args:
            x_1: Tensor with a single channel, shape (N, 1, H, W); fed to
                ``conv_x1``.
            x_2: Tensor with a single channel, shape (N, 1, H, W); fed to
                ``conv_x2``.

        Returns:
            Non-negative tensor with one channel. When H and W are multiples
            of 8 the spatial size equals that of the inputs.
        """
        # Use the method's own parameters; the previous body read the
        # undefined names ``x1``/``x2`` instead of ``x_1``/``x_2``.
        encoded_1 = self.conv_x1(x_1)
        encoded_2 = self.conv_x2(x_2)
        features = torch.cat((encoded_1, encoded_2), dim=1)
        return self.decoder(features)
Training process
# Build the model, the loss and the optimizer.
model = EncoderDecoderRegression().to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model
epoch = 50
train_loss, val_loss = [], []
for i in tqdm(range(epoch)):
    # Train loop ----------------------------
    model.train()
    train_batch_loss = []
    for y, x in train_loader:
        y, x = y.to(device).float(), x.to(device).float()
        # The input has 2 channels and each encoder starts with
        # Conv2d(1, ...), so every branch receives exactly one channel.
        # Slicing with a range (0:1, 1:2) keeps the channel dimension.
        x1 = x[:, 0:1, :, :]
        x2 = x[:, 1:2, :, :]
        optimizer.zero_grad()
        output = model(x1, x2)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        train_batch_loss.append(loss.item())

    # val loop ----------------------------
    model.eval()
    val_batch_loss = []
    with torch.no_grad():
        for y, x in val_loader:
            y, x = y.to(device).float(), x.to(device).float()
            # Same one-channel-per-branch split as in the training loop.
            x1 = x[:, 0:1, :, :]
            x2 = x[:, 1:2, :, :]
            output = model(x1, x2)
            loss = criterion(output, y)
            val_batch_loss.append(loss.item())

    # Collect loss: store the mean batch loss of this epoch.
    train_loss.append(np.mean(train_batch_loss))
    val_loss.append(np.mean(val_batch_loss))
    print(i, "Train loss: {a:.3f}, Val loss: {b:.3f}".format(
        a=train_loss[-1], b=val_loss[-1])
    )