PackedSequences from pack_padded_sequences on real data as a batch input to GRU / LSTM

Hello,

I was going through the PyTorch tutorials and got stuck on the name classification tutorial (http://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html) with some slight modifications of my own.

After two days of googling and lurking on the forum I can’t figure out how to properly use packed sequences in batches with an RNN. Could you please help me with that? The code works properly if I feed objects to the network one by one. I use some not-so-obvious NumPy-based approaches to vectorization because I’m more familiar with NumPy, but I hope to learn the “PyTorch way” in the future. I also know about torchtext, but I would like to understand the low-level mechanics of PyTorch first.
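
For reference, this is the pattern I believe I’m supposed to follow, shown as a minimal toy sketch with made-up sizes (not my actual code): pad the batch into a (seq_len, batch, input_size) tensor, sort the sequences by decreasing length, pack them, and feed the packed batch to the GRU together with an initial hidden state.

import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence

# Toy batch: 3 sequences of lengths 5, 3, 2 (must be sorted in decreasing order),
# input (one-hot) size 10, hidden size 4
lengths = [5, 3, 2]
padded = Variable(torch.zeros(5, 3, 10))        # (max_seq_len, batch, input_size)
packed = pack_padded_sequence(padded, lengths)  # PackedSequence; padded positions are skipped

gru = nn.GRU(input_size=10, hidden_size=4)
h0 = Variable(torch.zeros(1, 3, 4))             # (num_layers, batch, hidden_size)
output, h_n = gru(packed, h0)                   # output is a PackedSequence, h_n has shape (1, 3, 4)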

Code:

import random
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch import autograd
from torch.autograd import Variable 
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

import glob
import string
import unicodedata

import numpy as np
import time
import math

all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)
N_HIDDEN = 16
USE_CUDA = torch.cuda.is_available()

class RNN(torch.nn.Module):

	def __init__(self, n_classes, gru_size = N_HIDDEN):
		super(RNN, self).__init__()
		self.gru = nn.GRU(input_size=n_letters, hidden_size=gru_size)
		self.linear = nn.Linear(gru_size, n_classes)

	def forward(self, input, lengths):
		output, self.hidden = self.gru(input, self.hidden)
		x = output[-1]
		x = F.relu(x)
		x = self.linear(x)
		x = F.log_softmax(x)
		return x

	def init_hidden(self, batch_size):
		hidden = autograd.Variable(torch.zeros(1, batch_size, N_HIDDEN))
		if USE_CUDA:
			self.hidden = hidden.cuda()
		else:
			self.hidden = hidden


def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )


def vectorize(seqs):
	X = np.zeros((30, len(seqs), 1, n_letters), dtype=np.uint8)
	for i, sequence in enumerate(seqs):
		for t, letter in enumerate(sequence):
			X[t, i, 0, all_letters.find(letter)] = 1
	return X


# Slice numpy arrays by the given indices and convert them to Variables
def get_seq(ind):
 	if USE_CUDA:
 		return (autograd.Variable(torch.from_numpy(seq_arr)[:, ind]).cuda(), 
 				autograd.Variable(torch.from_numpy(cat_vec[ind])).cuda())
 	else:
 		return (autograd.Variable(torch.from_numpy(seq_arr[:, ind]).type(torch.FloatTensor)), 
 				autograd.Variable(torch.from_numpy(cat_vec[ind]).type(torch.LongTensor)))


def timeSince(since):
    now = time.time()
    sec = now - since
    s = sec
    m = math.floor(sec / 60)
    s -= m * 60
    return '%dm %ds' % (m, s), sec


file_paths = glob.glob("./data/names/*.txt")
len_vec = []
cat_vec = []
seq_vec = []
all_categories = []
cat_dict = {}
for file in file_paths:
	cat = file.split('/')[-1].split('.')[0]
	if cat not in cat_dict:
		cat_dict[cat] = len(cat_dict)
	with open(file, encoding='utf-8') as inp:
		for line in inp:
			seq_vec.append(line.strip())
			cat_vec.append(cat_dict[cat])
			all_categories.append(cat)
			len_vec.append(len(line.strip()))


# Sort sequences in descending order by their length
temp = sorted(zip(seq_vec, cat_vec), reverse=True, key = lambda x: len(x[0]))
seq_arr = vectorize([x[0] for x in temp])
cat_vec = np.array([x[1] for x in temp])
len_vec = np.array([len(x[0]) for x in temp])

rnn = RNN(len(cat_dict))

if USE_CUDA:
	rnn = rnn.cuda()

criterion = nn.NLLLoss()
learning_rate = .002
optimizer = torch.optim.RMSprop(rnn.parameters(), lr=learning_rate)

start = time.time()

for iter in range(10):
	all_loss = []
	print(iter+1)

	for ind in range(64):
		# Batches of size 16, just get some random indices
		# And yes, I know this isn't the right way to make batches
		inds = sorted(np.random.choice(20000, 16, replace=False))
		x, y = get_seq(inds)
		x = pack_padded_sequence(x, len_vec[inds])

		rnn.init_hidden(len(inds))
		rnn.zero_grad()
		y_pred = rnn(x, len_vec[inds])

		loss = criterion(y_pred, y)
		loss.backward()

		optimizer.zero_grad()
		optimizer.step()

		all_loss.append(loss.data[0])

	tstr, sec = timeSince(start)
	print(round(sec / (iter+1), 3))
	print(sum(all_loss))

	print()

Error:

  File "trch.py", line 158, in <module>
    y_pred = rnn(x, len_vec[inds])
  File "/Users/vdn/anaconda/lib/python3.5/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "trch.py", line 38, in forward
    output, self.hidden = self.gru(input, self.hidden)
  File "/Users/vdn/anaconda/lib/python3.5/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/vdn/anaconda/lib/python3.5/site-packages/torch/nn/modules/rnn.py", line 91, in forward
    output, hidden = func(input, self.all_weights, hx)
  File "/Users/vdn/anaconda/lib/python3.5/site-packages/torch/nn/_functions/rnn.py", line 343, in forward
    return func(input, *fargs, **fkwargs)
  File "/Users/vdn/anaconda/lib/python3.5/site-packages/torch/nn/_functions/rnn.py", line 243, in forward
    nexth, output = func(input, hidden, weight)
  File "/Users/vdn/anaconda/lib/python3.5/site-packages/torch/nn/_functions/rnn.py", line 83, in forward
    hy, output = inner(input, hidden[l], weight[l])
  File "/Users/vdn/anaconda/lib/python3.5/site-packages/torch/nn/_functions/rnn.py", line 154, in forward
    hidden = (inner(step_input, hidden[0], *weight),)
  File "/Users/vdn/anaconda/lib/python3.5/site-packages/torch/nn/_functions/rnn.py", line 53, in GRUCell
    gi = F.linear(input, w_ih, b_ih)
  File "/Users/vdn/anaconda/lib/python3.5/site-packages/torch/nn/functional.py", line 449, in linear
    return state(input, weight) if bias is None else state(input, weight, bias)
  File "/Users/vdn/anaconda/lib/python3.5/site-packages/torch/nn/_functions/linear.py", line 10, in forward
    output.addmm_(0, 1, input, weight.t())
RuntimeError: matrices expected, got 3D, 2D tensors at /Users/soumith/miniconda2/conda-bld/pytorch_1493757035034/work/torch/lib/TH/generic/THTensorMath.c:1232

After careful investigation I found the errors and got it working. The extra singleton dimension created in vectorize made the GRU input 4D instead of the expected 3D (seq_len, batch, input_size), which is what produced the "matrices expected, got 3D, 2D tensors" error. I was also calling optimizer.zero_grad() after loss.backward(), which wiped the gradients before optimizer.step(). Finally, instead of taking output[-1] (which does not give the last step of each sequence when the output is a PackedSequence), I now use the final hidden state returned by the GRU. Here is the final code with the errors highlighted:

import random
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch import autograd
from torch.autograd import Variable 
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

import glob
import string
import unicodedata

import numpy as np
import time
import math

all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)
N_HIDDEN = 16
USE_CUDA = torch.cuda.is_available()

class RNN(torch.nn.Module):

	def __init__(self, n_classes, gru_size = N_HIDDEN):
		super(RNN, self).__init__()
		self.gru = nn.GRU(input_size=n_letters, hidden_size=gru_size)
		self.linear = nn.Linear(gru_size, n_classes)

	def forward(self, input, lengths):
		# New: use the final hidden state of each sequence instead of output[-1]
		output, y = self.gru(input, self.hidden)
		x = y.squeeze(0)  # hidden state has shape (1, batch, hidden_size)
		x = F.relu(x)
		x = self.linear(x)
		x = F.log_softmax(x)
		return x

	def init_hidden(self, batch_size):
		hidden = autograd.Variable(torch.zeros(1, batch_size, N_HIDDEN))
		if USE_CUDA:
			self.hidden = hidden.cuda()
		else:
			self.hidden = hidden


def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )


def vectorize(seqs):
	# Error: the unnecessary extra dimension made the GRU input 4D instead of 3D
	# X = np.zeros((30, len(seqs), 1, n_letters), dtype=np.uint8)
	X = np.zeros((30, len(seqs), n_letters), dtype=np.uint8)
	for i, sequence in enumerate(seqs):
		for t, letter in enumerate(sequence):
			X[t, i, all_letters.find(letter)] = 1
	return X


# Slice numpy arrays by the given indices and convert them to Variables
def get_seq(ind):
	if USE_CUDA:
		return (autograd.Variable(torch.from_numpy(seq_arr[:, ind]).type(torch.FloatTensor)).cuda(),
				autograd.Variable(torch.from_numpy(cat_vec[ind]).type(torch.LongTensor)).cuda())
	else:
		return (autograd.Variable(torch.from_numpy(seq_arr[:, ind]).type(torch.FloatTensor)),
				autograd.Variable(torch.from_numpy(cat_vec[ind]).type(torch.LongTensor)))


def timeSince(since):
    now = time.time()
    sec = now - since
    s = sec
    m = math.floor(sec / 60)
    s -= m * 60
    return '%dm %ds' % (m, s), sec


file_paths = glob.glob("./data/names/*.txt")
len_vec = []
cat_vec = []
seq_vec = []
all_categories = []
cat_dict = {}
for file in file_paths:
	cat = file.split('/')[-1].split('.')[0]
	if cat not in cat_dict:
		cat_dict[cat] = len(cat_dict)
	with open(file, encoding='utf-8') as inp:
		for line in inp:
			seq_vec.append(line.strip())
			cat_vec.append(cat_dict[cat])
			all_categories.append(cat)
			len_vec.append(len(line.strip()))


# Sort sequences in descending order by their length
temp = sorted(zip(seq_vec, cat_vec), reverse=True, key = lambda x: len(x[0]))
seq_arr = vectorize([x[0] for x in temp])
cat_vec = np.array([x[1] for x in temp])
len_vec = np.array([len(x[0]) for x in temp])

rnn = RNN(len(cat_dict))

if USE_CUDA:
	rnn = rnn.cuda()

criterion = nn.NLLLoss()
learning_rate = .002
optimizer = torch.optim.RMSprop(rnn.parameters(), lr=learning_rate)

start = time.time()

for iter in range(10):
	all_loss = []
	print(iter+1)

	for ind in range(64):
		# New: zero the gradients here, at the start of each batch
		optimizer.zero_grad()
		# Batches of size 16, just get some random indices
		# And yes, I know this isn't the right way to make batches
		inds = sorted(np.random.choice(20000, 16, replace=False))
		x, y = get_seq(inds)
		x = pack_padded_sequence(x, len_vec[inds])

		rnn.init_hidden(len(inds))
		rnn.zero_grad()
		y_pred = rnn(x, len_vec[inds])

		loss = criterion(y_pred, y)
		loss.backward()

		# Error: zeroing the gradients here, after the backward pass, wipes them out before optimizer.step()
		# optimizer.zero_grad()
		optimizer.step()

		all_loss.append(loss.data[0])

	tstr, sec = timeSince(start)
	print(round(sec / (iter+1), 3))
	print(sum(all_loss))

	print()
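
One more thing I learned along the way: if you need the per-timestep outputs (or the last valid output of each sequence) rather than the final hidden state, the packed output has to be unpacked with pad_packed_sequence first. A minimal sketch of how I believe this works, with made-up sizes and names that are not from the code above:

import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lengths = [4, 2]                                  # batch of 2, decreasing lengths
padded = Variable(torch.randn(4, 2, 10))          # (max_seq_len, batch, input_size)
gru = nn.GRU(input_size=10, hidden_size=4)
h0 = Variable(torch.zeros(1, 2, 4))

output, h_n = gru(pack_padded_sequence(padded, lengths), h0)
unpacked, lens = pad_packed_sequence(output)      # back to (max_seq_len, batch, hidden_size)

# Last valid timestep of each sequence; for a single-layer GRU this should match h_n[0]
last_steps = torch.stack([unpacked[l - 1, i] for i, l in enumerate(lens)])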