Custom loss learning only partially

Hi,

I am implementing the paper "Hiding Images in Plain Sight: Deep Steganography", which uses three neural networks to hide a secret image inside a cover image and then reveal it again. My loss function is a weighted sum of two mean squared error (MSE) terms:

  1. The error between the secret image and the revealed output
  2. The error between the cover image and the hidden image, i.e. the cover with the secret embedded inside (the combined loss is written out just below)
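
For reference, this is the combined loss as I understand it from the paper, with beta weighting the secret-reconstruction term (c/c' are the original cover and the hidden image, s/s' are the original and revealed secret):

L(c, c', s, s') = ||c - c'||^2 + beta * ||s - s'||^2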

However, during training only the second term is optimized (i.e. the network learns to make the hidden image look like the original cover, but not to reveal the secret). I have no clue why this might be happening.

[image: sample results, in order: secret image, output image, cover image, hidden image]

Any clue would be much appreciated!

Here are the network architectures, the loss function, and the training code.

Neural network architectures:

import torch
import torch.nn as nn

# Preparation Network (2 blocks of 3 conv layers)
class PrepNetwork(nn.Module):
    def __init__(self):
        super(PrepNetwork, self).__init__()
        self.layer1P = nn.Sequential(
            nn.Conv2d(3, 50, kernel_size=3, padding=1),
            nn.ReLU(),
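            # note: kernel_size=4 with padding=1 shrinks H and W by one pixel;
            # the kernel_size=4, padding=2 conv in layer2P grows them back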
            nn.Conv2d(50, 50, kernel_size=4, padding=1),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=5, padding=2),
            nn.ReLU())
        self.layer2P = nn.Sequential(
            nn.Conv2d(50, 50, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=4, padding=2),
            nn.ReLU(),
            nn.Conv2d(50, 3, kernel_size=5, padding=2))

    def forward(self, p):
        p1 = self.layer1P(p)
        out = self.layer2P(p1)
        return out

# Hiding Network (5 blocks of conv layers)
class HidingNetwork(nn.Module):
    def __init__(self):
        super(HidingNetwork, self).__init__()
        self.layer1H = nn.Sequential(
            nn.Conv2d(6, 50, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=4, padding=1),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=5, padding=2),
            nn.ReLU())
        self.layer2H = nn.Sequential(
            nn.Conv2d(50, 50, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=4, padding=2),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=5, padding=2),
            nn.ReLU())
        self.layer3H = nn.Sequential(
            nn.Conv2d(50, 50, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=4, padding=1),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=5, padding=2),
            nn.ReLU())
        self.layer4H = nn.Sequential(
            nn.Conv2d(50, 50, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=4, padding=2),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=5, padding=2),
            nn.ReLU())
        self.layer5H = nn.Sequential(
            nn.Conv2d(50, 50, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(50, 3, kernel_size=5, padding=2))
        
    def forward(self, h):
        h1 = self.layer1H(h)
        h2 = self.layer2H(h1)
        h3 = self.layer3H(h2)
        h4 = self.layer4H(h3)
        out = self.layer5H(h4)
        return out

# Reveal Network (3 blocks of conv layers)
class RevealNetwork(nn.Module):
    def __init__(self):
        super(RevealNetwork, self).__init__()
        self.layer1R = nn.Sequential(
            nn.Conv2d(3, 50, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=4, padding=1),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=5, padding=2),
            nn.ReLU())
        self.layer2R = nn.Sequential(
            nn.Conv2d(50, 50, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(50, 50, kernel_size=4, padding=2),
            nn.ReLU(),
            nn.Conv2d(50, 3, kernel_size=5, padding=2),
            nn.ReLU())
        self.layer3R = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=1, padding=0))

    def forward(self, r):
        r1 = self.layer1R(r)
        r2 = self.layer2R(r1)
        out = self.layer3R(r2)
        return out

# Join the three networks into one module
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.m1 = PrepNetwork()
        self.m2 = HidingNetwork()
        self.m3 = RevealNetwork()

    def forward(self, secret, cover):
        x_1 = self.m1(secret)             # preprocessed secret
        mid = torch.cat((x_1, cover), 1)  # concatenate along the channel dim
        x_2 = self.m2(mid)                # hidden image (cover + secret)
        x_3 = self.m3(x_2)                # revealed secret
        return x_2, x_3
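
A quick shape sanity check (a minimal sketch; the batch size and the 64x64 resolution are arbitrary):

# All three networks preserve spatial size overall, so both outputs
# should match the input shape
net = Net()
secret = torch.randn(2, 3, 64, 64)
cover = torch.randn(2, 3, 64, 64)
hidden, output = net(secret, cover)
print(hidden.shape, output.shape)  # both torch.Size([2, 3, 64, 64])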

Loss function:

# Combined loss: cover reconstruction plus beta-weighted secret reconstruction
def customized_loss(S_prime, C_prime, S, C, B):
    '''S/S_prime: original and revealed secret; C/C_prime: original cover
    and hidden image; B: beta, the weight on the secret term.'''
    loss_cover = torch.nn.functional.mse_loss(C_prime, C)
    loss_secret = torch.nn.functional.mse_loss(S_prime, S)
    loss_all = loss_cover + B * loss_secret
    return loss_all, loss_cover, loss_secret
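
To rule out a broken gradient path, here is a quick check (a sketch, reusing Net and customized_loss from above) that the secret term alone sends gradients into the prep network:

# If loss_secret alone produces nonzero gradients in m1, the graph is intact
net = Net()
secret = torch.randn(2, 3, 64, 64)
cover = torch.randn(2, 3, 64, 64)
hidden, output = net(secret, cover)
_, _, loss_secret = customized_loss(output, hidden, secret, cover, B=1.0)
loss_secret.backward()
print(net.m1.layer1P[0].weight.grad.abs().mean())  # expect > 0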

Training:

import torch.optim as optim

# Create the optimizer (net, learning_rate, beta, num_epochs and
# train_loader are defined elsewhere)
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

# Iterate over batches performing forward and backward passes
for epoch in range(num_epochs):

    # Train mode
    net.train()

    # Train one epoch
    for idx, train_batch in enumerate(train_loader):

        data, _ = train_batch

        # Split each batch: first half are covers, second half secrets
        train_covers = data[:len(data)//2]
        train_secrets = data[len(data)//2:]

        # Forward + Backward + Optimize
        # (Variable is deprecated since PyTorch 0.4; plain tensors work
        # as inputs, and they do not need gradients themselves)
        optimizer.zero_grad()
        train_hidden, train_output = net(train_secrets, train_covers)

        # Calculate loss and perform backprop
        train_loss, train_loss_cover, train_loss_secret = customized_loss(
            train_output, train_hidden, train_secrets, train_covers, beta)
        train_loss.backward()
        optimizer.step()
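
        # (debugging sketch) log both loss components every 100 batches to
        # see which term is actually decreasing
        if idx % 100 == 0:
            print('epoch {}, batch {}: loss_cover={:.4f}, loss_secret={:.4f}'
                  .format(epoch, idx, train_loss_cover.item(),
                          train_loss_secret.item()))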