Hi,
I am practicing PyTorch and implemented the character-level seq2seq example from Keras:
https://keras.io/examples/lstm_seq2seq/
Below is my implementation:
from __future__ import unicode_literals, print_function, division
import numpy as np
import torch
import torch.nn as nn
from torch import optim
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_path = './eng_fra.txt'
num_samples = 10000  # number of samples to train on, as in the Keras example

# Vectorize the data.
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
with open(data_path, 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')
for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text = line.split('\t')
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = '\t' + target_text + '\n'
    input_texts.append(input_text)
    target_texts.append(target_text)
    for char in input_text:
        if char not in input_characters:
            input_characters.add(char)
    for char in target_text:
        if char not in target_characters:
            target_characters.add(char)
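# To illustrate the start/stop convention (example string assumed, not taken from
# the dataset): a target like "Va !" becomes "\tVa !\n", so the decoder always
# starts generating from a fixed '\t' token and learns to emit '\n' when done.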
input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
print('input_characters',input_characters)
num_decoder_tokens = len(target_characters)
print('target_characters',target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])
print('max_encoder_seq_length and max_decoder_seq_length',max_encoder_seq_length,max_decoder_seq_length)
input_token_index = dict(
    [(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict(
    [(char, i) for i, char in enumerate(target_characters)])
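# Illustration (exact indices depend on the data): characters are sorted, so if
# ' ' is the smallest character present, input_token_index[' '] == 0, and that
# index selects the one-hot column used for ' ' below.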
# Allocate the one-hot arrays.
encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
decoder_input_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
decoder_target_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
# One-hot encode each character of each sentence.
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.
    for t, char in enumerate(target_text):
        # decoder_target_data is ahead of decoder_input_data by one timestep
        # and does not include the start character.
        decoder_input_data[i, t, target_token_index[char]] = 1.
        if t > 0:
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.
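# Worked example (assuming a target like "\tVa !\n"): decoder_input_data encodes
# '\t' at t=0 and 'V' at t=1, while decoder_target_data already holds 'V' at t=0;
# i.e. the targets are the inputs shifted one step left, with the start '\t' dropped.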
encoder_input_data = torch.Tensor(encoder_input_data).to(device)
decoder_input_data = torch.Tensor(decoder_input_data).to(device)
decoder_target_data = torch.Tensor(decoder_target_data).to(device)
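# Design note: the whole dataset is pushed to the device once up front. For the
# ~10k sentence pairs used here that typically fits in GPU memory and avoids
# per-batch host-to-device copies.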
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.LSTM = nn.LSTM(input_size=num_encoder_tokens, hidden_size=256, batch_first=True)

    def forward(self, x):
        # Only the final hidden and cell states are needed to seed the decoder.
        out, (h, c) = self.LSTM(x)
        return h, c

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.LSTM = nn.LSTM(input_size=num_decoder_tokens, hidden_size=256, batch_first=True)
        self.FC = nn.Linear(256, num_decoder_tokens)

    def forward(self, x, hidden):
        out, (h, c) = self.LSTM(x, hidden)
        out = self.FC(out)  # project hidden states to vocabulary logits
        return out, (h, c)
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, encode_input_data, decode_input_data):
        # Teacher forcing: the decoder receives the whole shifted target sequence,
        # conditioned on the encoder's final states.
        hidden, cell = self.encoder(encode_input_data)
        output, (hidden, cell) = self.decoder(decode_input_data, (hidden, cell))
        return output
encoder = Encoder().to(device)
decoder = Decoder().to(device)
model = Seq2Seq(encoder, decoder).to(device)
optimizer = optim.RMSprop(model.parameters(), lr=0.01)
loss_fun = nn.CrossEntropyLoss()  # computes softmax internally, so the decoder outputs raw logits
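# Optional sanity check (my addition, not part of the Keras example): one small
# forward pass to confirm the decoder output lines up with the target tensor.
with torch.no_grad():
    probe = model(encoder_input_data[:2], decoder_input_data[:2])
assert probe.shape == (2, max_decoder_seq_length, num_decoder_tokens)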
num_epochs = 50
batches = np.array_split(range(decoder_target_data.shape[0]), 100)
total_step = len(batches)
for epoch in range(num_epochs):
    for i, batch_ids in enumerate(batches):
        encoder_input = encoder_input_data[batch_ids]
        decoder_input = decoder_input_data[batch_ids]
        decoder_target = decoder_target_data[batch_ids]
        output = model(encoder_input, decoder_input)
        # CrossEntropyLoss expects class indices, so the one-hot targets are
        # converted with argmax; use num_decoder_tokens instead of hard-coding 93.
        loss = loss_fun(output.view(-1, num_decoder_tokens),
                        decoder_target.view(-1, num_decoder_tokens).argmax(dim=1))
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 20 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
# Reverse-lookup token index to decode sequences back to something readable.
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())
def decode_sequence(input_seq):
    # Encode the input as state vectors.
    h, c = model.encoder(input_seq)
    # Generate an empty target sequence of length 1 and populate its first
    # character with the start character.
    target_seq = torch.zeros((1, 1, num_decoder_tokens)).to(device)
    target_seq[0, 0, target_token_index['\t']] = 1.
    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, (h_t, c_t) = model.decoder(target_seq, (h, c))
        # Sample a token (greedy argmax over the vocabulary).
        sampled_token_index = output_tokens.view(-1, num_decoder_tokens)[0].argmax().item()
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char
        # Exit condition: either hit max length or find the stop character.
        if (sampled_char == '\n' or
                len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True
        # Update the target sequence (of length 1).
        target_seq = torch.zeros((1, 1, num_decoder_tokens)).to(device)
        target_seq[0, 0, sampled_token_index] = 1.
        # Update states
        h, c = h_t, c_t
    return decoded_sentence
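# Note: decoding here runs with autograd enabled; wrapping the calls below in
# torch.no_grad() would avoid building graphs during inference.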
for seq_index in range(100):
    # Take one sequence (part of the training set) for trying out decoding.
    input_seq = encoder_input_data[seq_index: seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print('-')
    print('Input sentence:', input_texts[seq_index])
    print('Decoded sentence:', decoded_sentence)
Basically I followed exactly the same data-processing steps and the same encoder/decoder structure as the Keras example, but my results are worse than Keras's. Did I do anything wrong?