Non-linearities in Denoising RNN Autoencoders

Hi there,

I’m using a denoising RNN autoencoder for a project involving motion capture data. This is my first time working with autoencoder architectures, and I’m wondering which non-linearities belong in these models and where they should be placed. Here is my model as it stands:

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(EncoderRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # nn.RNN already applies a tanh non-linearity at every time step
        self.rnn_enc = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.relu_enc = nn.ReLU()

    def forward(self, x):
        pred, hidden = self.rnn_enc(x, None)  # None -> zero initial hidden state
        pred = self.relu_enc(pred)
        return pred

class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers

        self.rnn_dec = nn.RNN(input_size=hidden_size, hidden_size=output_size, num_layers=num_layers, batch_first=True)
        self.relu_dec = nn.ReLU()

    def forward(self, x):
        pred, hidden = self.rnn_dec(x, None)
        pred = self.relu_dec(pred)  # final activation: clamps the reconstruction to be non-negative
        return pred

class RNNAE(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(RNNAE, self).__init__()
        self.encoder = EncoderRNN(input_size, hidden_size, num_layers)
        self.decoder = DecoderRNN(hidden_size, input_size, num_layers)

    def forward(self, x):
        encoded_input = self.encoder(x)
        decoded_output = self.decoder(encoded_input)
        return decoded_output

As you can see, I have a ReLU non-linearity in both the encoder and the decoder, but I’m not sure whether this is the correct placement for this architecture.
The model trains on the data OK, but the MSE loss stops improving after the first few epochs, and I suspect the ReLU functions are the cause: motion capture coordinates can be negative, and the final ReLU clamps the decoder’s output to be non-negative.
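To illustrate my suspicion, here’s a toy check (made-up target values, just a sketch): whatever the decoder computes internally, a trailing ReLU can never reproduce negative targets, so the MSE has a hard floor.

```python
import torch

# Hypothetical targets containing negative values, as mocap coordinates can
target = torch.tensor([-1.0, -0.5, 0.5, 1.0])

# The closest a ReLU'd output can ever get: negatives are clamped to 0
best_possible = torch.relu(target)

# The reconstruction error can never drop below this floor
mse_floor = torch.mean((best_possible - target) ** 2)
print(mse_floor.item())  # 0.3125
```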

Any advice on how to improve this, or confirmation that it’s basically correct, would be very helpful.
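For reference, this is the variant I’m considering: the decoder with its final ReLU removed, so the reconstruction can take negative values (a sketch of my code above, not a definitive fix — the sizes in the shape check are arbitrary):

```python
import torch
import torch.nn as nn

class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.rnn_dec = nn.RNN(input_size=hidden_size, hidden_size=output_size,
                              num_layers=num_layers, batch_first=True)

    def forward(self, x):
        # nn.RNN already applies tanh internally; with no extra ReLU,
        # the reconstruction can take negative values
        pred, hidden = self.rnn_dec(x, None)
        return pred

# quick shape check: batch of 2, 10 time steps, 16-dim latent -> 8-dim output
dec = DecoderRNN(hidden_size=16, output_size=8, num_layers=1)
out = dec(torch.randn(2, 10, 16))
print(out.shape)  # torch.Size([2, 10, 8])
```

Note that the output here still passes through the RNN’s internal tanh, so it’s bounded to [-1, 1]; if the data isn’t normalised to that range, I understand a common alternative is to project the RNN’s hidden states through a final nn.Linear layer so the output range is unbounded.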