Hi folks,
I have read a lot about attention mechanisms in Encoder-Decoder networks.
All the examples I’ve found use an Encoder -> Attention -> Decoder architecture.
I use an LSTM for next-class prediction (the input is a sequence of 10 concatenated BERT embeddings, so n_input = 10 * 768); more precisely, I’m trying to do anomaly detection. It works and the results are “ok”: recall is 1.0, but the F1 score is “only” 0.55. So I was thinking of implementing an attention mechanism.
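As a side note on those numbers, recall and F1 together pin down the precision. The small sketch below solves the F1 formula for precision; it assumes the reported F1 is the plain binary F1 of the anomaly class (not a macro or weighted average), and `precision_from_f1` is just a hypothetical helper name.

```python
def precision_from_f1(f1, recall):
    """Solve F1 = 2 * P * R / (P + R) for the precision P."""
    return f1 * recall / (2 * recall - f1)


precision = precision_from_f1(f1=0.55, recall=1.0)
print(round(precision, 3))  # 0.379
```

Under that assumption, roughly 38% of the flagged samples are real anomalies, so the weak point is false positives rather than missed anomalies.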
import torch.nn as nn
import torch.nn.functional as F


class LSTM(nn.Module):
    def __init__(self, n_input, n_hidden_units, n_layers, n_classes, tie_weights=False, train_mode=False):
        super(LSTM, self).__init__()
        self.n_hidden_units = n_hidden_units
        self.n_layers = n_layers
        self.train_mode = train_mode
        # Layers
        # Attention layers: defined here, but not used in forward() yet
        self.attn = nn.Linear(self.n_hidden_units * 2, n_input)
        self.attn_combine = nn.Linear(self.n_hidden_units * 2, self.n_hidden_units)
        self.lstm = nn.LSTM(input_size=n_input, hidden_size=n_hidden_units, num_layers=n_layers,
                            dropout=0.2, batch_first=True)
        self.decoder = nn.Linear(n_hidden_units, n_classes)
        if tie_weights:
            # This model has no encoder (embedding) layer, so there is nothing to tie to
            raise ValueError('tie_weights is not supported: the model has no encoder layer')
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        # output: (batch, seq_len, n_hidden_units) because batch_first=True
        output, hidden = self.lstm(input, hidden)
        if self.train_mode:
            output = F.dropout(output, p=0.1)
        # Many-to-one: only the output of the last time step is classified
        decoded = self.decoder(output[:, -1, :])
        log_probs = F.log_softmax(decoded, dim=1)
        return log_probs, hidden

    def init_hidden(self, bsz, device):
        # Zero-initialised (h_0, c_0), each of shape (n_layers, batch, n_hidden_units)
        weight = next(self.parameters())
        return (weight.new_zeros(self.n_layers, bsz, self.n_hidden_units).to(device),
                weight.new_zeros(self.n_layers, bsz, self.n_hidden_units).to(device))
I know from the examples that the attention layers have to go somewhere before the LSTM. All the examples are many-to-many, but I’m doing many-to-one, and I don’t want to just drop the layers in and make it work somehow without knowing exactly how it’s done right.
So if someone could show me where in the code I would have to put what exactly, that would be very helpful.
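To make the question concrete, here is a minimal sketch of one common way to add attention to a many-to-one LSTM. It is not necessarily the right choice for this task. In this variant the attention sits after the LSTM, not before it: every time step's output gets a score, the scores are softmaxed over time, and the weighted sum of the outputs replaces output[:, -1, :] as the input to the decoder. The class name `AttentionLSTM`, the single-layer scoring function and the sizes in the usage lines are all assumptions for illustration. The sketch also assumes the 10 BERT embeddings are fed as 10 time steps of 768 features each, not as one concatenated vector of 10 * 768, since attention over time needs more than one step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionLSTM(nn.Module):
    """Hypothetical many-to-one LSTM classifier with attention pooling."""

    def __init__(self, n_input, n_hidden_units, n_layers, n_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size=n_input, hidden_size=n_hidden_units,
                            num_layers=n_layers, batch_first=True)
        # Maps each time step's hidden state to a single scalar score
        self.attn = nn.Linear(n_hidden_units, 1)
        self.decoder = nn.Linear(n_hidden_units, n_classes)

    def forward(self, input, hidden=None):
        # output: (batch, seq_len, n_hidden_units)
        output, hidden = self.lstm(input, hidden)
        # scores: (batch, seq_len), one score per time step
        scores = self.attn(torch.tanh(output)).squeeze(-1)
        # weights: (batch, seq_len), summing to 1 over the time dimension
        weights = F.softmax(scores, dim=1)
        # context: weighted sum of the outputs, (batch, n_hidden_units)
        context = torch.bmm(weights.unsqueeze(1), output).squeeze(1)
        log_probs = F.log_softmax(self.decoder(context), dim=1)
        return log_probs, hidden, weights


# Usage with made-up sizes: batch of 4, 10 time steps, 768 features each
model = AttentionLSTM(n_input=768, n_hidden_units=64, n_layers=1, n_classes=5)
x = torch.randn(4, 10, 768)
log_probs, hidden, weights = model(x)
print(log_probs.shape, weights.shape)
```

Returning the weights is optional, but it makes it easy to inspect which time steps the model attends to for a given prediction.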