import torch
import torch.nn as nn
import torch.nn.functional as F
class BiLSTMWithAttention(nn.Module):
    """Bidirectional LSTM encoder with attention pooling and an MLP classifier head."""

    def __init__(self, embedding_dim=768, hidden_dim=128, num_classes=3):
        super().__init__()
        self.bilstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
        # Scores each timestep's 2*hidden_dim representation with a single scalar.
        self.attn = nn.Linear(2 * hidden_dim, 1)
        self.dropout = nn.Dropout(0.3)
        # The attention context vector has 2*hidden_dim features, so fc1 must accept that size.
        self.fc1 = nn.Linear(2 * hidden_dim, 128)
        self.fc2 = nn.Linear(128, num_classes)
    def forward(self, x):
        # x shape: (batch_size, seq_len, embedding_dim)
        lstm_out, _ = self.bilstm(x)                         # (batch, seq_len, 2*hidden_dim)
        attn_scores = self.attn(lstm_out)                    # (batch, seq_len, 1)
        attn_weights = torch.softmax(attn_scores, dim=1)     # (batch, seq_len, 1), normalized over seq_len
        context = torch.sum(attn_weights * lstm_out, dim=1)  # (batch, 2*hidden_dim)
        x = self.dropout(F.relu(self.fc1(context)))          # (batch, 128)
        return self.fc2(x)                                   # (batch, num_classes)
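
# Minimal usage sketch. The batch size, sequence length, and random inputs below are
# illustrative assumptions, not values taken from the original code.
if __name__ == "__main__":
    model = BiLSTMWithAttention(embedding_dim=768, hidden_dim=128, num_classes=3)
    dummy_batch = torch.randn(4, 16, 768)  # (batch_size=4, seq_len=16, embedding_dim=768)
    logits = model(dummy_batch)
    print(logits.shape)                    # expected: torch.Size([4, 3])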