Hi all.
I want to train an LSTM network to fill in a form from a given text.
After converting my inputs and labels to tokens, I get something like:
tensor([[ 0, 15, 14, ..., 0, 0, 0],
[ 0, 15, 14, ..., 0, 0, 0],
[ 0, 15, 14, ..., 0, 0, 0],
[ 0, 15, 14, ..., 0, 0, 0],
[ 0, 15, 14, ..., 43, 0, 0],
[ 0, 15, 14, ..., 0, 0, 0]])
for the input text, and:
tensor([[ 175, 153, 202, 242, 482],
[ 186, 26, 627, 363, 393],
[ 171, 26, 190, 835, 246],
[ 160, 149, 313, 449, 337],
[ 157, 152, 1001, 1024, 408],
[ 186, 149, 551, 1020, 288],
[ 158, 154, 445, 1051, 259],
[ 159, 149, 791, 1104, 432],
[ 166, 155, 553, 780, 259],
[ 159, 153, 767, 462, 246],
[ 168, 149, 188, 378, 302],
[ 179, 149, 237, 461, 291],
[ 163, 152, 1000, 690, 235],
[ 165, 153, 223, 968, 289],
[ 164, 154, 445, 236, 215],
[ 159, 153, 442, 390, 807],
[ 183, 154, 765, 1036, 488],
[ 181, 154, 221, 970, 480],
[ 157, 153, 905, 473, 249],
[ 161, 154, 369, 242, 479],
[ 169, 153, 271, 196, 253],
[ 178, 149, 476, 1110, 488],
[ 159, 26, 241, 222, 342],
[ 163, 152, 534, 467, 232],
[ 179, 152, 818, 200, 302],
[ 173, 26, 1169, 602, 267],
[ 158, 149, 306, 581, 329],
[ 179, 153, 191, 304, 492],
[ 174, 149, 788, 311, 301],
[ 173, 149, 545, 883, 302],
[ 163, 152, 910, 543, 411],
[ 183, 26, 446, 319, 505]])
for my labels.
Now I am confused about how to pass them to the loss function.
My network class is:
class FormFillingModel(nn.Module):
    """Embedding -> single-layer LSTM -> five independent linear heads.

    Each head produces ``num_tags`` logits from one summary vector of the
    sequence. The heads are fc1 (age), fc2 (action), fc3 (first_name),
    fc4 (last_name) and fc5 (date).

    Args:
        vocab_size: Number of input token ids (embedding rows).
        embed_size: Embedding dimension.
        hidden_size: LSTM hidden dimension.
        num_tags: Number of classes per head. ``CrossEntropyLoss`` requires
            every target to lie in ``[0, num_tags)``, so this must be larger
            than the largest label id used for any field.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_tags):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc1 = nn.Linear(hidden_size, num_tags)  # For age
        self.fc2 = nn.Linear(hidden_size, num_tags)  # For action
        self.fc3 = nn.Linear(hidden_size, num_tags)  # For first_name
        self.fc4 = nn.Linear(hidden_size, num_tags)  # For last_name
        self.fc5 = nn.Linear(hidden_size, num_tags)  # For date

    def forward(self, x, lengths=None):
        """Return a tuple of five logit tensors, one per form field.

        Args:
            x: Token ids indexed as ``(batch, seq)`` (``batch_first=True``).
            lengths: Optional true (unpadded) length of each sequence, as a
                list of ints or a 1-D CPU int64 tensor. When given, padding
                after each sequence is ignored and the LSTM's final hidden
                state for the real tokens is used. When omitted, the LSTM
                output at the last time step is used, matching the original
                behaviour. For right-padded batches that output has also
                been run over the padding tokens.

        Returns:
            ``(age, action, first_name, last_name, date)`` logits, each of
            shape ``(batch, num_tags)``.
        """
        embedded = self.embedding(x)
        if lengths is None:
            output, _ = self.lstm(embedded)
            summary = output[:, -1, :]
        else:
            # Packing stops the LSTM at each sequence's real end. With
            # enforce_sorted=False the returned h_n is already restored to
            # the original batch order.
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths, batch_first=True, enforce_sorted=False
            )
            _, (h_n, _) = self.lstm(packed)
            summary = h_n[-1]  # last (only) layer's final hidden state
        return (
            self.fc1(summary),
            self.fc2(summary),
            self.fc3(summary),
            self.fc4(summary),
            self.fc5(summary),
        )

    @staticmethod
    def compute_loss(outputs, labels, criterion=None):
        """Sum one classification loss per head.

        Args:
            outputs: The tuple returned by ``forward`` (five logit tensors).
            labels: Integer class ids with one column per head, in the same
                order as ``outputs``. This is the ``(batch, 5)`` label tensor
                from the question.
            criterion: Loss applied to each ``(logits, labels[:, i])`` pair.
                Defaults to ``nn.CrossEntropyLoss()``.

        Returns:
            Scalar tensor: the sum of the per-head losses.
        """
        if criterion is None:
            criterion = nn.CrossEntropyLoss()
        return sum(
            criterion(logits, labels[:, i]) for i, logits in enumerate(outputs)
        )