Transformer outputting a constant vector in binary classification

I am new to PyTorch and I'm attempting to train a transformer model on biological sequences for binary classification. My training loop is functional in the sense that it iterates without errors; however, the loss never decreases and the validation accuracy stays constant at 50%. Inspecting the output of a single forward pass shows that it is a tensor of constant values. I'm assuming something is wrong with my forward method (perhaps the max pooling?).

My input is of shape (n, 100, 4), where n is the number of samples, 100 is the sequence length, and 4 is the one-hot encoding of DNA.
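
For reference, here is a minimal sketch of how DNA strings can be turned into a tensor of that shape. It assumes the channel order A, C, G, T and uppercase sequences without ambiguous bases such as N; the helper name `one_hot_dna` is hypothetical and not part of the original code.

```python
import torch
import torch.nn.functional as F

# Assumed channel order A, C, G, T (the post does not specify the order used)
NUC_TO_IDX = {"A": 0, "C": 1, "G": 2, "T": 3}

def one_hot_dna(seqs):
    """Hypothetical helper: encode equal-length DNA strings as a float tensor (n, L, 4)."""
    idx = torch.tensor([[NUC_TO_IDX[b] for b in s] for s in seqs])  # (n, L), int64
    # F.one_hot appends a trailing class dimension; cast to float for the nn layers
    return F.one_hot(idx, num_classes=4).float()

seqs = ["ACGT" * 25, "TTGCA" * 20]  # two made-up 100-bp sequences
x = one_hot_dna(seqs)
print(x.shape)  # torch.Size([2, 100, 4])
```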

Here is my model definition:

```python
import torch
import torch.nn as nn

class NeuralNetwork(nn.Module):
    '''
    Build a neural network transformer for one-hot encoded DNA sequences
    '''
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        # Three stacked encoder layers; d_model=4 matches the one-hot channel dimension
        self.transformer1 = nn.TransformerEncoderLayer(
            d_model=4,
            nhead=2,
            batch_first=True,
            dim_feedforward=1024
        )
        self.transformer2 = nn.TransformerEncoderLayer(
            d_model=4,
            nhead=2,
            batch_first=True,
            dim_feedforward=1024
        )
        self.transformer3 = nn.TransformerEncoderLayer(
            d_model=4,
            nhead=2,
            batch_first=True,
            dim_feedforward=1024
        )
        # Linear head producing one logit per sample
        self.linear = nn.Linear(
            in_features=4,
            out_features=1
        )

    def forward(self, x):
        # x: (batch, seq_len, 4)
        x = self.transformer1(x)
        x = self.transformer2(x)
        x = self.transformer3(x)
        # Max pooling over the sequence dimension -> (batch, 4)
        x = torch.max(x, dim=1)[0]
        logits = self.linear(x)        # (batch, 1)
        return torch.flatten(logits)   # (batch,)
```
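
The pooling step in `forward` can be checked in isolation. This sketch feeds a random tensor, standing in for the encoder output, through the same `torch.max(..., dim=1)` reduction and a linear head, and prints the shape at each stage; the batch size of 8 is arbitrary.

```python
import torch
import torch.nn as nn

n, seq_len, d_model = 8, 100, 4
x = torch.randn(n, seq_len, d_model)  # stand-in for the encoder output (batch_first=True)

# torch.max with dim returns a (values, indices) pair; [0] and .values select the same tensor
pooled = torch.max(x, dim=1).values     # max over the sequence axis -> (n, d_model)
logits = nn.Linear(d_model, 1)(pooled)  # (n, 1)
flat = torch.flatten(logits)            # (n,), same shape as a (n,) float target

print(pooled.shape, logits.shape, flat.shape)
# torch.Size([8, 4]) torch.Size([8, 1]) torch.Size([8])
```

Each of the 4 channels keeps its maximum over all 100 positions, so the linear head only ever sees one value per channel per sequence.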

I am using `BCEWithLogitsLoss` and the Adam optimizer. If I print `output = model(input)`, I get a constant tensor.
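
A constant output is consistent with the symptoms described above. This sketch assumes balanced labels (the post does not give the class ratio). On balanced labels the best any constant output can do is a logit of 0, which gives a `BCEWithLogitsLoss` of ln 2 ≈ 0.693 and 50% accuracy, so a loss stuck near 0.693 is a hint that the outputs have collapsed. The labels below are made up for illustration.

```python
import math
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
target = torch.tensor([0., 1., 0., 1., 0., 0., 1., 1.])  # made-up balanced labels

# A "model" that emits the same logit (0) for every sample
constant_logits = torch.zeros_like(target)
loss = criterion(constant_logits, target)
print(loss.item(), math.log(2))  # both approximately 0.6931

# Predict class 1 when the logit is > 0; a constant output predicts one class for
# every sample, so accuracy equals that class's share of the labels (50% here)
preds = (constant_logits > 0).float()
accuracy = (preds == target).float().mean().item()
print(accuracy)  # 0.5

# Quick check for collapsed outputs: spread of the logits across the batch
print(constant_logits.std().item())  # 0.0
```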

Let me know if I need to include more details; I am trying not to clutter the post. I posted this in NLP since I assumed you are the resident experts on transformers. Please correct me if I should move it.

I can properly overfit random samples using your model:

```python
import torch
import torch.nn as nn

# NeuralNetwork is the model class defined in the question
model = NeuralNetwork()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

# Small random batch: 8 sequences of length 100 with 4 channels, random binary labels
x = torch.randn(8, 100, 4)
target = torch.randint(0, 2, (8,)).float()

for _ in range(1000):
    optimizer.zero_grad()
    out = model(x)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()
    print("loss {:.6f}".format(loss.item()))

print(out)
# tensor([-3.4665,  2.0541, -3.1483,  4.0016, -3.1309, -3.1514,  3.8889,  4.5529],
#        grad_fn=<ReshapeAliasBackward0>)
print(target)
# tensor([0., 1., 0., 1., 0., 0., 1., 1.])
```

So the architecture itself is able to learn, and you might need to experiment more with the hyperparameters of your training.
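
One way to run such an experiment is a small learning-rate sweep. This is a minimal sketch on random data, not a recipe: the learning rates, the step count, and the class name `SmallClassifier` are arbitrary illustrations. `SmallClassifier` is a hypothetical compact restatement of the model above, with three independently initialized encoder layers in an `nn.Sequential`, max pooling over the sequence axis, and a linear head.

```python
import torch
import torch.nn as nn

class SmallClassifier(nn.Module):
    """Hypothetical compact version of the model above (same layers, same forward)."""
    def __init__(self, dim_feedforward=1024):
        super().__init__()
        # Three separately constructed (independently initialized) encoder layers
        self.encoder = nn.Sequential(*[
            nn.TransformerEncoderLayer(d_model=4, nhead=2, batch_first=True,
                                       dim_feedforward=dim_feedforward)
            for _ in range(3)
        ])
        self.head = nn.Linear(4, 1)

    def forward(self, x):
        x = self.encoder(x)               # (batch, seq_len, 4)
        x = torch.max(x, dim=1).values    # max pooling over the sequence axis
        return torch.flatten(self.head(x))

def train(lr, x, y, steps=100):
    torch.manual_seed(0)  # same initial weights for every learning rate
    model = SmallClassifier()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    return loss.item()  # loss computed at the last step

torch.manual_seed(0)
x = torch.randn(8, 100, 4)
y = torch.randint(0, 2, (8,)).float()

results = {lr: train(lr, x, y) for lr in (1e-2, 1e-3, 1e-4)}
for lr, final_loss in results.items():
    print(f"lr={lr:g}  final loss={final_loss:.4f}")
```

Reseeding inside `train` means each learning rate starts from the same weights, so differences in the final loss come from the learning rate rather than from the initialization. On real data, the comparison would be made on the validation loss and accuracy rather than on the training loss.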
