Constant training/validation accuracy

Hi,
I have short DNA reads of length 48, each composed of the four nucleotides (“A”, “T”, “C”, “G”). The problem I am working on is binary classification: distinguishing the sequences that belong to Class I from those that belong to Class II. My model consists of one Conv1d layer followed by several fully connected layers. During training I keep getting almost constant training accuracy/loss and validation accuracy/loss. I would appreciate your help a lot.
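For context, each sequence is first converted to integer indices (one index per nucleotide) so it can be passed to the embedding layer; here is a simplified sketch of that step (the mapping is just illustrative, not my exact preprocessing):

import torch

# illustrative nucleotide-to-index mapping
NUC2IDX = {"A": 0, "T": 1, "C": 2, "G": 3}

def encode(seq):
    # map a length-48 DNA string to a LongTensor of indices for nn.Embedding
    return torch.tensor([NUC2IDX[n] for n in seq], dtype=torch.long)

batch = torch.stack([encode("ATCG" * 12), encode("GGTA" * 12)])  # shape (2, 48)

And here is the model I am running: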

class CNNNet(nn.Module):
    #def __init__(self, input_size, hidden_size, num_layers, d_out):
    def __init__(self, voc_size, emb_dim, d_out):
        super(CNNNet,self).__init__()
        
        self.voc_size = voc_size
        self.emb_dim = emb_dim
        self.d_out = d_out
        
        self.embedding = nn.Embedding(self.voc_size, self.emb_dim)
        self.cnn1 = nn.Conv1d(in_channels=35, out_channels=128, kernel_size=11)
        self.maxpool1 = nn.MaxPool1d(kernel_size=5, stride=1)

        self.fc1 = nn.Linear(4352, 2048)
        self.fc2 = nn.Linear(2048, 1024)
        self.fc3 = nn.Linear(1024, 256)
        self.fc4 = nn.Linear(256, 64)
        self.fc5 = nn.Linear(64, 1)
        self.relu = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(self.d_out)

    def forward(self, x):
        embeds = self.embedding(x)
        embeds = embeds.view(embeds.shape[0], embeds.shape[2], embeds.shape[1])     
        cnn_layer1 = self.relu(self.cnn1(embeds))
        mpool1 = self.maxpool1(cnn_layer1)
        x = mpool1.view(mpool1.size(0), -1)
        out = self.relu(self.dropout(self.fc1(x)))
        out = self.relu(self.dropout(self.fc2(out)))
        out = self.relu(self.dropout(self.fc3(out)))
        out = self.relu(self.dropout(self.fc4(out)))
        out = self.sigmoid(self.relu((self.fc5(out))))
        return (out)

Here is the training progress:

Epoch:  0  /Loss is:  35.44  /Acc:  0.499  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  1  /Loss is:  49.983  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  2  /Loss is:  49.983  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  3  /Loss is:  49.983  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  4  /Loss is:  49.98  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  5  /Loss is:  49.986  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  6  /Loss is:  49.985  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  7  /Loss is:  49.99  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  8  /Loss is:  49.98  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  9  /Loss is:  49.983  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  10  /Loss is:  49.98  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  11  /Loss is:  49.987  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  12  /Loss is:  49.984  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  13  /Loss is:  49.987  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5
Epoch:  14  /Loss is:  49.98  /Acc:  0.5  /Val_Loss:  50.0 / Val Acc:  0.5

The usage of sigmoid on top of relu looks a bit strange, so you might want to check if removing the relu helps. Also, assuming you are using nn.BCELoss, replace it with nn.BCEWithLogitsLoss and remove the sigmoid for better numerical stability. Once this is done, try to overfit a small dataset (e.g. just 10 samples) by playing around with some hyperparameters.
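In code, the change would look roughly like this (just a sketch, with a random tensor standing in for the model output):

import torch
import torch.nn as nn

# nn.BCEWithLogitsLoss applies the sigmoid internally, so the model should
# return the raw output of self.fc5 (no relu, no sigmoid)
criterion = nn.BCEWithLogitsLoss()
logits = torch.randn(8, 1)                    # placeholder for model(x)
target = torch.randint(0, 2, (8, 1)).float()
loss = criterion(logits, target)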

I made the changes and tried the final layer both with and without the relu:

class CNNNet(nn.Module):
    #def __init__(self, input_size, hidden_size, num_layers, d_out):
    def __init__(self, voc_size, emb_dim, d_out):
        super(CNNNet,self).__init__()
        
        self.voc_size = voc_size
        self.emb_dim = emb_dim
        self.d_out = d_out
        
        self.embedding = nn.Embedding(self.voc_size, self.emb_dim)
        self.cnn1 = nn.Conv1d(in_channels=35, out_channels=128, kernel_size=11)
        self.maxpool1 = nn.MaxPool1d(kernel_size=5, stride=1)

        self.fc1 = nn.Linear(4352, 2048)
        self.fc2 = nn.Linear(2048, 1024)
        self.fc3 = nn.Linear(1024, 256)
        self.fc4 = nn.Linear(256, 64)
        self.fc5 = nn.Linear(64, 1)
        self.relu = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(self.d_out)

    def forward(self, x):
        embeds = self.embedding(x)
        embeds = embeds.view(embeds.shape[0], embeds.shape[2], embeds.shape[1])     
        cnn_layer1 = self.relu(self.cnn1(embeds))
        mpool1 = self.maxpool1(cnn_layer1)
        x = mpool1.view(mpool1.size(0), -1)
        out = self.relu(self.dropout(self.fc1(x)))
        out = self.relu(self.dropout(self.fc2(out)))
        out = self.relu(self.dropout(self.fc3(out)))
        out = self.relu(self.dropout(self.fc4(out)))
        out = self.relu(self.dropout(self.fc5(out)))
        return (out)

I replaced nn.BCELoss with nn.BCEWithLogitsLoss, but I still keep getting the same training and validation accuracy. The loss for both training and validation is increasing. I don't know if there is something wrong with my final layer.

Epoch:  0  /Loss is:  22.967  /Acc:  0.5  /Val_Loss:  0.305 / Val Acc:  0.5
Epoch:  1  /Loss is:  26935.934  /Acc:  0.5  /Val_Loss:  55.215 / Val Acc:  0.499
Epoch:  2  /Loss is:  1912778.449  /Acc:  0.501  /Val_Loss:  6714.289 / Val Acc:  0.499
Epoch:  3  /Loss is:  63861942.892  /Acc:  0.5  /Val_Loss:  160974.908 / Val Acc:  0.497
Epoch:  4  /Loss is:  36171148.704  /Acc:  0.5  /Val_Loss:  29366.505 / Val Acc:  0.5
Epoch:  5  /Loss is:  8249850.308  /Acc:  0.501  /Val_Loss:  67750.454 / Val Acc:  0.501
Epoch:  6  /Loss is:  54260504.278  /Acc:  0.5  /Val_Loss:  169734.745 / Val Acc:  0.501
Epoch:  7  /Loss is:  30756635.652  /Acc:  0.5  /Val_Loss:  31760.747 / Val Acc:  0.502
Epoch:  8  /Loss is:  36212964.045  /Acc:  0.5  /Val_Loss:  383223.904 / Val Acc:  0.502
Epoch:  9  /Loss is:  102597196.442  /Acc:  0.5  /Val_Loss:  710555.625 / Val Acc:  0.5
Epoch:  10  /Loss is:  312780892.063  /Acc:  0.5  /Val_Loss:  1806529.0 / Val Acc:  0.499
Epoch:  11  /Loss is:  194196309.733  /Acc:  0.501  /Val_Loss:  371710.664 / Val Acc:  0.5
Epoch:  12  /Loss is:  79394308.874  /Acc:  0.5  /Val_Loss:  306873.801 / Val Acc:  0.5

Remove the self.relu and self.dropout usage from the output and return the output of self.fc5 instead.
Also, check that this view operation is right, as you might be interleaving the tensor:

embeds = embeds.view(embeds.shape[0], embeds.shape[2], embeds.shape[1])

If you want to swap axes, use permute.
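A quick illustration of the difference: view just reinterprets the flat memory (interleaving the values), while permute actually swaps the axes:

import torch

x = torch.arange(6).reshape(1, 2, 3)  # [[[0, 1, 2], [3, 4, 5]]]
print(x.view(1, 3, 2))                # [[[0, 1], [2, 3], [4, 5]]] -> values interleaved
print(x.permute(0, 2, 1))             # [[[0, 3], [1, 4], [2, 5]]] -> axes swapped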
The following code, which tries to overfit a random data sample, fails with your model but works after the proposed changes:

import torch
import torch.nn as nn

class CNNNet(nn.Module):
    #def __init__(self, input_size, hidden_size, num_layers, d_out):
    def __init__(self, voc_size, emb_dim, d_out):
        super(CNNNet,self).__init__()
        
        self.voc_size = voc_size
        self.emb_dim = emb_dim
        self.d_out = d_out
        
        self.embedding = nn.Embedding(self.voc_size, self.emb_dim)
        self.cnn1 = nn.Conv1d(in_channels=35, out_channels=128, kernel_size=11)
        self.maxpool1 = nn.MaxPool1d(kernel_size=5, stride=1)

        self.fc1 = nn.Linear(4352, 2048)
        self.fc2 = nn.Linear(2048, 1024)
        self.fc3 = nn.Linear(1024, 256)
        self.fc4 = nn.Linear(256, 64)
        self.fc5 = nn.Linear(64, 1)
        self.relu = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(self.d_out)

    def forward(self, x):
        embeds = self.embedding(x)
        embeds = embeds.permute(0, 2, 1).contiguous()  # (batch, emb_dim, seq_len) for Conv1d
        cnn_layer1 = self.relu(self.cnn1(embeds))
        mpool1 = self.maxpool1(cnn_layer1)
        x = mpool1.view(mpool1.size(0), -1)
        out = self.relu(self.dropout(self.fc1(x)))
        out = self.relu(self.dropout(self.fc2(out)))
        out = self.relu(self.dropout(self.fc3(out)))
        out = self.relu(self.dropout(self.fc4(out)))
        out = self.fc5(out)  # return raw logits; nn.BCEWithLogitsLoss applies the sigmoid
        return (out)
    
model = CNNNet(10, 35, 0.5)                    # voc_size=10, emb_dim=35, dropout=0.5
x = torch.randint(0, 10, (64, 48))             # batch of 64 random index sequences of length 48
target = torch.randint(0, 2, (64, 1)).float()  # random binary targets
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(100):
    optimizer.zero_grad()
    out = model(x)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()
    print("epoch {}, loss {}".format(epoch, loss.item()))