Hello, I wrote this network based on this paper.
As you can see, I am using two for loops: one over the essays in the batch and one over the sentences in each essay.
I am sure there is a better, more efficient way to do this without the for loops, but I could not figure it out (a rough sketch of the reshaping idea is included after the code).
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch import Tensor
from typing import Optional, Tuple
import torch.nn.functional as F
class Attention(nn.Module):
    # implementation from https://mlwhiz.com/blog/2019/03/09/deeplearning_architectures_text_classification/
    def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.supports_masking = True

        self.bias = bias
        self.feature_dim = feature_dim
        self.step_dim = step_dim
        self.features_dim = 0

        weight = torch.zeros(feature_dim, 1)
        nn.init.kaiming_uniform_(weight)
        self.weight = nn.Parameter(weight)

        if bias:
            self.b = nn.Parameter(torch.zeros(step_dim))
    def forward(self, x, mask=None):
        feature_dim = self.feature_dim
        step_dim = self.step_dim

        eij = torch.mm(x.contiguous().view(-1, feature_dim), self.weight).view(-1, step_dim)

        if self.bias:
            eij = eij + self.b

        eij = torch.tanh(eij)
        a = torch.exp(eij)

        if mask is not None:
            a = a * mask

        a = a / (torch.sum(a, 1, keepdim=True) + 1e-10)

        weighted_input = x * torch.unsqueeze(a, -1)
        return torch.sum(weighted_input, 1)
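
# Note on shapes: the Attention module above pools over the middle ("step")
# dimension, i.e. it expects input of shape (batch, step_dim, feature_dim) and
# returns (batch, feature_dim); the optional mask is a 0/1 tensor of shape
# (batch, step_dim).
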
class STL(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, filters, max_word, max_sent):
        super(STL, self).__init__()
        # CNN layer
        self.word_level_cnn = nn.Conv1d(embedding_dim, filters, kernel_size=5)
        # Dropout layer
        self.dropout = nn.Dropout(p=0.5)
        # LSTM layer
        self.sentence_level_rnn = nn.LSTM(max_word - 5 + 1, hidden_dim)
        # Dense layer
        self.dense = nn.Linear(max_sent, 1)
        # word-level attention
        self.wiz_word_attention = Attention(max_word - 5 + 1, filters)
        # sentence-level attention
        self.wiz_sent_attention = Attention(max_sent, filters)
    def forward(self, x):
        batch_size, max_sent, max_word, embedding_dim = x.size()

        essays_rep = []
        for i in range(batch_size):  # looping through the essays in the batch
            sent_rep = []
            lstm_rep = None
            for j in range(max_sent):  # looping through each sentence of the essay
                # input size is (max_word, embedding_dim)
                sentence = x[i][j]

                # passing each sentence to the CNN with input size (batch=1, embedding_dim, max_word),
                # output size (1, filters, max_word-5+1)
                cnn_out = self.word_level_cnn(sentence.permute(1, 0).unsqueeze(0))

                # applying attention pooling
                pooled_output = self.wiz_word_attention(cnn_out)

                # applying the LSTM to each sentence, carrying the hidden state across sentences
                if lstm_rep is None:
                    lstm_output = self.sentence_level_rnn(pooled_output.unsqueeze(0))
                else:
                    lstm_output = self.sentence_level_rnn(pooled_output.unsqueeze(0), lstm_rep[1])
                lstm_rep = lstm_output
                lstm_output = lstm_output[0]

                sent_rep.append(lstm_output.squeeze(0))
            essays_rep.append(torch.stack(sent_rep))

        x = torch.stack(essays_rep)
        x = x.squeeze(2)

        # sentence attention pooling
        pooled_output = self.wiz_sent_attention(x)

        # Dense layer then sigmoid for normalization
        score = self.dense(pooled_output)
        score = nn.Sigmoid()(score)
        return score
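
For the loop question itself, here is only a minimal sketch (not a drop-in replacement for the model above, and every size below is made up) of the usual reshaping trick: fold the batch and sentence dimensions together so a single Conv1d call processes every sentence of every essay at once, then unfold again before the sentence-level LSTM. The mean pooling is just a stand-in for the word-level attention, and the LSTM hidden size of 64 is arbitrary.

import torch
import torch.nn as nn

# made-up sizes, only for illustration
batch_size, max_sent, max_word, embedding_dim, filters = 2, 3, 20, 50, 100

x = torch.randn(batch_size, max_sent, max_word, embedding_dim)

# fold essays and sentences into one "batch" dimension
folded = x.view(batch_size * max_sent, max_word, embedding_dim)
folded = folded.permute(0, 2, 1)  # Conv1d expects (N, channels, length)

word_level_cnn = nn.Conv1d(embedding_dim, filters, kernel_size=5)
cnn_out = word_level_cnn(folded)  # (batch_size * max_sent, filters, max_word - 5 + 1)

# word-level pooling would go here, still on the folded batch;
# mean pooling is only a placeholder for the attention pooling
pooled = cnn_out.mean(dim=2)  # (batch_size * max_sent, filters)

# unfold back to (batch_size, max_sent, features) so one LSTM call sees
# max_sent time steps per essay, instead of carrying the hidden state
# by hand in a Python loop
sentences = pooled.view(batch_size, max_sent, filters)
sentence_level_rnn = nn.LSTM(filters, 64, batch_first=True)
lstm_out, _ = sentence_level_rnn(sentences)  # (batch_size, max_sent, 64)

The same fold/unfold idea should carry over to the attention pooling, as long as the attention module is written to work on a whole batch of sentences at once.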