[codereview] seq2seq, no attention, for review

Hi,

Looking for a code review of the following seq2seq implementation. It's character-level, with a shared embedding, and with the encoder's own next-character predictions added to the loss function, for faster initial training.

import torch
from torch import nn, autograd, optim
import numpy as np
import math
import sys
import encoding
# import data_starredwords as data
import data_anki as data


N = 4                   # number of training examples to use
max_sentence_len = 10   # truncate sentences to this many characters
print_every = 2         # print sample encodes/decodes every this many epochs
hidden_size = 256


training = data.Data().get_training(N=N)
training = [
    {'input': ex['first'][:max_sentence_len], 'target': ex['second'][:max_sentence_len]}
    for ex in training
]
for n in range(min(N, 16)):
    print(n, training[n])

for example in training:
    example['input_encoded'] = encoding.encode_passage(example['input'])
    example['target_encoded'] = encoding.encode_passage(example['target'])

V = len(encoding.char_by_idx)
print('vocab size %s' % V)

torch.manual_seed(123)
np.random.seed(123)


class Encoder(nn.Module):
    def __init__(self, embedding):
        super().__init__()
        self.input_size = embedding.weight.size()[0]
        self.hidden_size = embedding.weight.size()[1]
        self.embedding = embedding
        self.rnn_enc = nn.RNN(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            num_layers=1,
            nonlinearity='tanh'
        )

    def forward(self, x, state):
        x = self.embedding(x)
        x, state = self.rnn_enc(x, state)
        return x, state


class Decoder(nn.Module):
    def __init__(self, embedding):
        super().__init__()
        self.input_size = embedding.weight.size()[0]
        self.hidden_size = embedding.weight.size()[1]
        self.embedding = embedding
        self.rnn_dec = nn.RNN(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            num_layers=1,
            nonlinearity='tanh'
        )

    def forward(self, x, state):
        x = self.embedding(x)
        x, state = self.rnn_dec(x, state)
        return x, state


optimizer_fn = optim.Adam
# optimizer_fn = optim.SGD

embedding = nn.Embedding(V, hidden_size)
encoder = Encoder(embedding=embedding)
decoder = Decoder(embedding=embedding)
embedding_matrix = embedding.weight

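# encoder and decoder both contain the shared embedding as a submodule, so
# collect the unique parameters with a set before handing them to the optimizer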
parameters = (
    set(encoder.parameters()) |
    set(decoder.parameters()) |
    set(embedding.parameters()))
opt = optimizer_fn(parameters, lr=0.001)

epoch = 0
while True:
    encoder_debug = ''
    decoder_debug = ''
    for n, ex in enumerate(training):
        input_encoded = ex['input_encoded']
        target_encoded = ex['target_encoded']
        input_len = len(input_encoded)
        target_len = len(target_encoded)

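        # teacher-force most epochs; run free (feed back own predictions) on the
        # epochs where sample decodes get printed, so the printout is meaningful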
        teacher_forcing = (epoch % print_every) != 0

        loss = 0
        # pred_c below holds raw logits, so use CrossEntropyLoss
        # (NLLLoss would expect log-probabilities)
        criterion = torch.nn.CrossEntropyLoss()

        # encode
        def encode(input_encoded, state):
            global encoder_debug
            enc_loss = 0
            prev_c = encoding.start_code
            input_sentence_verify = ''
            sentence = ''
            # [1:] skips the start token; prev_c starts at the start code, so at
            # each step the encoder predicts the next character of the input
            for t, input_c in enumerate(input_encoded[1:]):
                input_c = input_c.item()
                input_sentence_verify += encoding.char_by_idx[input_c]
                pred_c_embedded, state = encoder(autograd.Variable(torch.LongTensor([[prev_c]])), state)
                pred_c = pred_c_embedded.view(-1, hidden_size) @ embedding_matrix.transpose(0, 1)
                _, v = pred_c.max(-1)
                v = v.item()
                sentence += encoding.char_by_idx[v]
                # want to force encoder to build language model a bit faster than
                # if it has to wait only for gradient from decoder:
                enc_loss += criterion(pred_c, autograd.Variable(
                    torch.LongTensor([input_c])))
                prev_c = input_c
            if n <= 4 and epoch % print_every == 0:
                if n == 0:
                    encoder_debug += 'epoch %s encoder:\n' % epoch
                encoder_debug += '    [%s] => [%s]\n' % (input_sentence_verify, sentence)
            return state, enc_loss

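        # initial hidden state for the encoder RNN: (num_layers, batch, hidden_size)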
        state = autograd.Variable(torch.zeros(1, 1, hidden_size))
        state, enc_loss = encode(input_encoded, state)
        loss += enc_loss

        # decode
        prev_c = encoding.start_code

        output_sentence = ''
        for t, target_c in enumerate(target_encoded[1:]):
            target_c = target_c.item()

            pred_c_embedded, state = decoder(
                autograd.Variable(torch.LongTensor([[prev_c]])), state)
            pred_c = pred_c_embedded.view(-1, hidden_size) @ embedding_matrix.transpose(0, 1)
            _, v = pred_c.max(-1)
            v = v.item()
            output_sentence += encoding.char_by_idx[v]
            loss += criterion(pred_c, autograd.Variable(torch.LongTensor(
                [target_c])))

            if teacher_forcing:
                prev_c = target_c
            else:
                # if we're already wrong, let's just abandon...
                if target_c != v:
                    break
                prev_c = v
        if n <= 1 and epoch % print_every == 0:
            if n == 0:
                decoder_debug += 'epoch %s decoder:\n' % epoch
            if not teacher_forcing:
                decoder_debug += '    [%s] => [%s] [%s]\n' % (
                    ex['input'], ex['target'], output_sentence)
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(parameters, 4.0)
        opt.step()

    if encoder_debug != '':
        print(encoder_debug)
    if decoder_debug != '':
        print(decoder_debug)

    epoch += 1
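
For reference, the core idea is that the one shared embedding matrix is used both to embed characters and, transposed, to project RNN outputs back onto character logits (that's what `... @ embedding_matrix.transpose(0, 1)` is doing above, for both the encoder's auxiliary next-character loss and the decoder). A minimal sketch of just that tied projection, with illustrative names and sizes rather than the ones used above:

import torch
from torch import nn

V_sketch, H_sketch = 32, 256                 # illustrative vocab / hidden sizes
E = nn.Embedding(V_sketch, H_sketch)         # one shared embedding matrix
rnn = nn.RNN(input_size=H_sketch, hidden_size=H_sketch, num_layers=1)

tokens = torch.LongTensor([[3]])             # one character, batch size 1
state = torch.zeros(1, 1, H_sketch)          # (num_layers, batch, hidden)
out, state = rnn(E(tokens), state)           # (seq, batch, hidden)
logits = out.view(-1, H_sketch) @ E.weight.t()   # tied projection -> (1, V)
loss = nn.CrossEntropyLoss()(logits, torch.LongTensor([3]))

Tying the output projection to the embedding keeps the parameter count down, and the hope is that the language-model loss on the encoder pushes useful character representations into that shared matrix early on.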

(Since I'm using similar examples to create tutorial YouTube videos, it seems like making sure it's more or less recognizable, idiomatic PyTorch is probably a good idea :slight_smile: )

At first glance, it looks great!