I have read a lot about attention mechanisms in Encoder-Decoder networks.
All the examples I’ve found follow an Encoder -> Attention -> Decoder structure.
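For reference, the core of that Encoder -> Attention -> Decoder pattern is a single step. The current decoder state is scored against every encoder output, the scores are normalized with softmax, and the weighted sum of encoder outputs becomes a context vector. Below is a minimal PyTorch sketch. It assumes plain dot-product scoring rather than a learned scoring layer, and the tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

# Arbitrary sizes, for illustration only.
batch, src_len, hidden = 2, 5, 8

encoder_outputs = torch.randn(batch, src_len, hidden)  # one vector per source step
decoder_hidden = torch.randn(batch, hidden)            # current decoder state

# Dot-product score between the decoder state and every encoder output.
scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(2)).squeeze(2)  # (batch, src_len)
weights = F.softmax(scores, dim=1)                                          # sums to 1 over src_len
# Context vector: attention-weighted sum of the encoder outputs.
context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)      # (batch, hidden)

print(weights.shape, context.shape)
```

In a seq2seq model this context vector is fed to the decoder at each output step, which is why the examples are many-to-many.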
I use an LSTM for next-class prediction (more precisely, I’m trying to do anomaly detection). The input is a sequence of 10 concatenated BERT embeddings, so n_input = 10 * 768. It works, and the results are “ok”: recall is 1.0, but the F1 score is “only” 0.55. So I was thinking of implementing an attention mechanism.
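With recall at 1.0, the F1 score fixes precision, which shows where the errors come from. This assumes both numbers are computed for the anomaly class:

$$
F_1 = \frac{2PR}{P + R}, \quad R = 1 \;\Rightarrow\; F_1 = \frac{2P}{P + 1} \;\Rightarrow\; P = \frac{F_1}{2 - F_1} = \frac{0.55}{1.45} \approx 0.38
$$

So the model catches every anomaly, but roughly six out of ten flagged sequences are false positives.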
```python
import torch.nn as nn
import torch.nn.functional as F


class LSTM(nn.Module):
    def __init__(self, n_input, n_hidden_units, n_layers, n_classes, tie_weights=False, train_mode=False):
        super(LSTM, self).__init__()
        self.n_hidden_units = n_hidden_units
        self.n_layers = n_layers
        self.train_mode = train_mode

        # Layers
        # Attention layers (defined, but not yet used in forward())
        self.attn = nn.Linear(self.n_hidden_units * 2, n_input)
        self.attn_combine = nn.Linear(self.n_hidden_units * 2, self.n_hidden_units)
        self.lstm = nn.LSTM(input_size=n_input, hidden_size=n_hidden_units, num_layers=n_layers,
                            dropout=0.2, batch_first=True)
        self.decoder = nn.Linear(n_hidden_units, n_classes)

        if tie_weights:
            if n_input != n_hidden_units:
                raise ValueError('When using the tied flag, nhid must be equal to emsize')
            # NOTE: this model defines no self.encoder, so this line raises AttributeError.
            self.decoder.weight = self.encoder.weight

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        output, hidden = self.lstm(input, hidden)
        # Dropout on the LSTM outputs, applied only when train_mode is set
        output = F.dropout(output, p=0.1, training=self.train_mode)
        # Many-to-one: classify from the last time step only
        decoded = self.decoder(output[:, -1, :])
        log_probs = F.log_softmax(decoded, dim=1)
        return log_probs, hidden

    def init_hidden(self, bsz, device):
        weight = next(self.parameters())
        # (h_0, c_0), each of shape (n_layers, batch, n_hidden_units)
        return (weight.new_zeros(self.n_layers, bsz, self.n_hidden_units).to(device),
                weight.new_zeros(self.n_layers, bsz, self.n_hidden_units).to(device))
```
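The `forward()` above classifies from `output[:, -1, :]` only, which is the top layer's final hidden state. The outputs at earlier time steps are discarded, and that is the information that attention over the LSTM outputs would let the classifier weigh. Here is a small shape check. It uses arbitrary sizes rather than the real 7680-dimensional input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary small sizes; the post uses n_input = 10 * 768.
batch, seq_len, n_input, n_hidden, n_layers = 4, 6, 16, 32, 2

lstm = nn.LSTM(input_size=n_input, hidden_size=n_hidden,
               num_layers=n_layers, dropout=0.2, batch_first=True)
lstm.eval()  # disable inter-layer dropout so the comparison is deterministic

x = torch.randn(batch, seq_len, n_input)
output, (h_n, c_n) = lstm(x)

print(output.shape)  # (batch, seq_len, n_hidden): top-layer state at every time step
print(h_n.shape)     # (n_layers, batch, n_hidden): final state of each layer

# For a unidirectional LSTM, the last time step of `output` equals the top layer's final hidden state.
last_step = output[:, -1, :]
print(torch.allclose(last_step, h_n[-1]))
```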
From the examples, I know that the attention layers have to go somewhere before the LSTM. However, all the examples are many-to-many, while I’m doing many-to-one. I don’t want to just drop the layers in and somehow make it work without understanding how it’s done correctly.
So if someone could show me exactly what I need to add and where in the code, that would be very helpful.
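One common way to add attention to a many-to-one classifier uses a different placement from the one assumed above. Instead of going before the LSTM, attention pooling is applied to the LSTM outputs at every time step. A learned score per step, a softmax over time, and a weighted sum replace `output[:, -1, :]` as the decoder input. Below is a minimal sketch under that assumption. The class name `AttnLSTMClassifier` and the layers `proj`/`score` are hypothetical and not part of the original code. `tie_weights` and `train_mode` are left out for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnLSTMClassifier(nn.Module):
    """Hypothetical many-to-one model: LSTM followed by attention pooling over its outputs."""

    def __init__(self, n_input, n_hidden_units, n_layers, n_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size=n_input, hidden_size=n_hidden_units,
                            num_layers=n_layers, dropout=0.2, batch_first=True)
        # Attention pooling: project each time step, then map it to a single scalar score.
        self.proj = nn.Linear(n_hidden_units, n_hidden_units)
        self.score = nn.Linear(n_hidden_units, 1)
        self.decoder = nn.Linear(n_hidden_units, n_classes)

    def forward(self, input, hidden=None):
        output, hidden = self.lstm(input, hidden)                      # (batch, seq_len, n_hidden)
        scores = self.score(torch.tanh(self.proj(output))).squeeze(2)  # (batch, seq_len)
        weights = F.softmax(scores, dim=1)                             # attention over time steps
        context = torch.bmm(weights.unsqueeze(1), output).squeeze(1)   # (batch, n_hidden)
        log_probs = F.log_softmax(self.decoder(context), dim=1)
        return log_probs, hidden, weights


# Usage with arbitrary small sizes (the post uses n_input = 10 * 768).
model = AttnLSTMClassifier(n_input=16, n_hidden_units=32, n_layers=2, n_classes=2)
model.eval()
x = torch.randn(4, 6, 16)  # (batch, seq_len, n_input)
log_probs, hidden, weights = model(x)
print(log_probs.shape, weights.shape)
```

The function still returns log-probabilities, so presumably the existing NLL-based training loop would not need to change. Returning `weights` as well makes it possible to inspect which time steps a prediction relied on, which can help when checking the false positives.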