I’m trying to solve an NLP classification problem with an LSTM. The model is defined as follows:
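```python
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, hidden_size, embedding_size=66):
        super().__init__()
        # Bidirectional LSTM over one-hot character vectors (66 possible characters)
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True, bidirectional=True)
        # The final hidden states of both directions are concatenated, hence 2 * hidden_size inputs
        self.fc = nn.Linear(2 * hidden_size, 2)

    def forward(self, input_seq):
        output, (hidden_state, cell_state) = self.lstm(input_seq)
        # hidden_state: (num_layers * 2, batch, hidden_size); [-2] is the forward and [-1] the backward direction
        hidden_state = torch.cat((hidden_state[-1, :], hidden_state[-2, :]), -1)
        logits = self.fc(hidden_state)
        return nn.LogSoftmax(dim=1)(logits)
```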
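To make the tensor shapes in `forward` concrete, here is a small standalone check on a bare `nn.LSTM` with the same input size (66). The `hidden_size` of 32 and the sequence length of 10 are arbitrary values picked for illustration, not taken from my setup. For a single-layer bidirectional LSTM, `h_n[-2]` should be the forward direction's final state (equal to the forward half of `output` at the last time step) and `h_n[-1]` the backward direction's final state (equal to the backward half of `output` at the first time step), so the concatenation yields a `(batch, 2 * hidden_size)` tensor that fits `self.fc`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size = 32  # illustrative value only
lstm = nn.LSTM(66, hidden_size, batch_first=True, bidirectional=True)

x = torch.zeros(1, 10, 66)  # one sequence of 10 (placeholder) character vectors
output, (h_n, c_n) = lstm(x)

print(output.shape)  # (batch, seq_len, 2 * hidden_size)
print(h_n.shape)     # (num_layers * 2, batch, hidden_size)

# Forward direction finishes at the last time step, backward direction at the first one
forward_last = output[:, -1, :hidden_size]
backward_first = output[:, 0, hidden_size:]
print(torch.allclose(h_n[-2], forward_last), torch.allclose(h_n[-1], backward_first))

# Same concatenation as in LSTM.forward
combined = torch.cat((h_n[-1], h_n[-2]), -1)
print(combined.shape)  # (batch, 2 * hidden_size)
```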
And this is the function I use to train it:
```python
import torch

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    zeros = 0
    for batch, (X, y) in enumerate(dataloader):
        # Transform the string into a one-hot tensor of shape (1, seq_len, 66);
        # ctoi (defined elsewhere) maps each character to its index
        tensor = torch.zeros(1, len(X[0]), 66)
        for i in range(len(X[0])):
            tensor[0][i][ctoi[X[0][i]]] = 1
        pred = model(tensor)

        # Build a length-2 target with a 1 at position y
        target = torch.zeros(2, dtype=torch.long)
        target[y] = 1
        if batch % 100 == 0:
            print(pred.squeeze(), target)
        loss = loss_fn(pred.squeeze(), target)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count how often class 0 is predicted
        if pred.squeeze().argmax() == 0:
            zeros += 1
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
    print(f"In training predicted {zeros} zeroes out of {size} samples")
```
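The per-character loop that builds the one-hot tensor can also be written with `torch.nn.functional.one_hot`. The sketch below is only an illustration: the `ctoi` here is a hypothetical lowercase-letter map standing in for the real 66-character one, and `encode` is a hypothetical helper name. It should produce the same `(1, seq_len, 66)` float tensor as the loop.

```python
import torch
import torch.nn.functional as F

# Hypothetical character-to-index map; the real ctoi covers 66 characters
alphabet = "abcdefghijklmnopqrstuvwxyz"
ctoi = {ch: i for i, ch in enumerate(alphabet)}
num_chars = 66  # embedding_size used by the model

def encode(text):
    """One-hot encode a string into a (1, len(text), num_chars) float tensor."""
    indices = torch.tensor([ctoi[ch] for ch in text], dtype=torch.long)
    one_hot = F.one_hot(indices, num_classes=num_chars).float()
    return one_hot.unsqueeze(0)  # add the batch dimension (batch_first=True)

x = encode("hello")
print(x.shape)  # torch.Size([1, 5, 66])
```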
The X’s are still strings, which is why I convert them to one-hot tensors before passing them through the model. The y’s are either 0 or 1 (it’s a binary classification problem), and I convert them to a tensor of shape (2,) to pass to the loss function.
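For comparison, PyTorch documents the target of `nn.NLLLoss` as class indices (a `long` tensor of shape `(batch,)`, or a scalar for unbatched input), not as a one-hot vector. The sketch below assumes the loss is `nn.NLLLoss`, which matches the `LogSoftmax` output of the model but is not stated above; `nn.CrossEntropyLoss` takes the same index targets but expects raw logits, since it applies `LogSoftmax` itself. The logit values are arbitrary.

```python
import torch
import torch.nn as nn

loss_fn = nn.NLLLoss()  # assumption: the actual loss_fn is not shown in the post

# Log-probabilities for a single sample, kept batched: shape (batch=1, classes=2)
log_probs = torch.log_softmax(torch.tensor([[0.2, -0.1]]), dim=1)

y = 1                        # label as a plain integer
target = torch.tensor([y])   # class index, shape (batch,), dtype long

loss = loss_fn(log_probs, target)
# With the default mean reduction and no weights, this is -log_probs[0, y]
print(loss.item(), -log_probs[0, y].item())
```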
For some reason the model ends up predicting the same class for every input. The classes are not very unbalanced (roughly 45% vs. 55%), and I’ve tried changing the class weights in the loss function without any improvement: training always converges to predicting either all 0s or all 1s. Most of the time it converges to predicting all 0s, which makes even less sense because class 0 usually has fewer samples than class 1.
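For reference, class weights are passed to `nn.NLLLoss` (and `nn.CrossEntropyLoss`) through the `weight` argument, one entry per class. The sketch below computes inverse-frequency weights, which is one common choice but not necessarily what I tried; the 450/550 label counts are made up to mirror the rough 45% / 55% split.

```python
import torch
import torch.nn as nn

# Illustrative labels only, mirroring the rough 45% / 55% split
labels = torch.cat([torch.zeros(450, dtype=torch.long), torch.ones(550, dtype=torch.long)])

counts = torch.bincount(labels, minlength=2).float()  # tensor([450., 550.])
weights = counts.sum() / (len(counts) * counts)       # inverse-frequency weights

# weights[c] scales the loss of every sample whose target is class c
loss_fn = nn.NLLLoss(weight=weights)
print(weights)
```

With these counts the rarer class 0 gets a weight above 1 and class 1 a weight below 1.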