I’m trying to train a simple RNN model (plain RNN or LSTM). My model is this:
```python
import torch
import torch.nn as nn


class Sequence(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, num_classes):
        super(Sequence, self).__init__()
        self.n_layers = 1
        self.hidden_size = hidden_dim
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.encoder = nn.LSTM(embedding_dim, hidden_dim, num_layers=self.n_layers,
                               dropout=0, batch_first=True)
        self.classify = nn.Linear(hidden_dim, num_classes)

    def forward(self, smi, hidden):
        # smi: (batch, seq_len) integer ids; hidden: (h_0, c_0), e.g. from init_hidden
        smi = self.embedding(smi.long())
        smi, hidden = self.encoder(smi, hidden)
        # Classify from the LSTM output at the last time step
        out = self.classify(smi[:, -1, :])
        return out, hidden

    def init_weights(self):
        nn.init.xavier_uniform_(self.embedding.weight)
        nn.init.xavier_uniform_(self.classify.weight)
        nn.init.constant_(self.classify.bias, 0)
        for name, param in self.encoder.named_parameters():
            if 'weight' in name:
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0)
                # Second quarter of the bias vector (the forget gate in
                # PyTorch's i, f, g, o layout) is set to 1
                r_gate = param[int(0.25 * len(param)):int(0.5 * len(param))]
                nn.init.constant_(r_gate, 1)

    def init_hidden(self, bsz, device):
        # Zero initial hidden and cell states, each (n_layers, batch, hidden_size)
        return (torch.zeros(self.n_layers, bsz, self.hidden_size).to(device),
                torch.zeros(self.n_layers, bsz, self.hidden_size).to(device))
```
I don’t think I have to explain that much. The input is an array of integers (I encode each character of each sequence to an integer so that a string of N characters is encoded as an array of N integers): I pass it to an embedding layer, then an LSTM and finally a Linear layer for classification (it’s a multi-class classification task).
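As an illustration of the encoding described above, a minimal sketch might look like the following. The example strings, the `encode` helper, and right-padding with index 0 up to length 70 are assumptions made for the example, not necessarily the exact preprocessing used here.

```python
# Hypothetical example sequences; the real dataset is not shown in this post.
sequences = ["CCO", "c1ccccc1", "CC(=O)O"]

# Build a character vocabulary; index 0 is reserved for padding (assumption).
chars = sorted(set("".join(sequences)))
vocab = {ch: i + 1 for i, ch in enumerate(chars)}


def encode(seq, vocab, max_len=70, pad_idx=0):
    """Map each character to its integer id, then truncate/pad to max_len."""
    ids = [vocab[ch] for ch in seq][:max_len]
    return ids + [pad_idx] * (max_len - len(ids))


batch = [encode(s, vocab) for s in sequences]

# input_dim for nn.Embedding must cover every id, including the padding id.
input_dim = len(vocab) + 1
```

Wrapping `batch` in `torch.tensor(batch)` would give a tensor of shape `[3, 70]`, matching the `[batch_size, 70]` input described in this post.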
The problem is that the training loss is not decreasing; it just fluctuates a lot around the same value. I checked the weights and I don’t notice anything strange.
First, does anyone notice anything wrong in the model itself? This is the first time I’ve used these models and I’m not sure about the inputs. The input to the model has size `torch.Size([batch_size, 70])`, the output has size `torch.Size([batch_size, num_classes])`, and the labels have dimensions
Another thing I still don’t get is the `hidden` tensor, which is passed to an RNN model: is it useful or not? Why do some people use it while others don’t?
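To make the question about `hidden` concrete, here is a small sketch comparing the two usages on a bare `nn.LSTM`. When no initial state is passed, `nn.LSTM` defaults to zeros for both the hidden and cell state, so passing explicit zeros should give the same output. The layer sizes and the random input are arbitrary values chosen for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes for illustration: 8 input features, 16 hidden units.
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1, batch_first=True)
x = torch.randn(4, 70, 8)  # (batch, seq_len, features)

# Case 1: no initial state given; nn.LSTM starts from zeros.
out_default, (h_n, c_n) = lstm(x)

# Case 2: explicit zero initial state, shaped (num_layers, batch, hidden_size).
h0 = torch.zeros(1, 4, 16)
c0 = torch.zeros(1, 4, 16)
out_explicit, _ = lstm(x, (h0, c0))

# The two calls should produce the same outputs.
same = torch.allclose(out_default, out_explicit)

# For a unidirectional LSTM, the output at the last time step is the
# final hidden state of the last layer.
last_matches = torch.allclose(out_default[:, -1, :], h_n[-1])

print(same, last_matches)
```

Under these assumptions, passing the state explicitly only seems to matter when it is non-zero, for example when a state is carried over from one chunk of a long sequence to the next.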