Example of Many-to-One LSTM

Hi everyone,

Is there an example of Many-to-One LSTM in PyTorch? I am trying to feed a long vector and get a single label out. An LSTM or GRU example will really help me out.

My problem looks kind of like this:
Input = series of 5 vectors, output = single class label prediction.

Thanks!

10 Likes

Recurrent modules from torch.nn take an input sequence and output a sequence of the same length. Just take the last element from that output sequence.

Here is a small working example with a 2-layer LSTM neural network:

import torch
import torch.nn as nn

time_steps = 10
batch_size = 3
in_size = 5
classes_no = 7

model = nn.LSTM(in_size, classes_no, 2)  # 2-layer LSTM; per-step output size equals the number of classes
input_seq = torch.randn(time_steps, batch_size, in_size)  # (seq_len, batch, in_size)
output_seq, _ = model(input_seq)
last_output = output_seq[-1]  # output at the last time step, shape (batch_size, classes_no)

loss = nn.CrossEntropyLoss()
target = torch.randint(0, classes_no, (batch_size,))  # class indices in [0, classes_no)
err = loss(last_output, target)
err.backward()
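
For a unidirectional LSTM, the last element of output_seq is the same as the top layer's slice of the final hidden state, so last_output could equivalently be taken from h_n. A quick check of that claim, reusing the variables above:

output_seq, (h_n, c_n) = model(input_seq)
assert torch.allclose(output_seq[-1], h_n[-1])  # h_n[-1] is the top layer's final hidden state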
21 Likes

I am still confused about the recurrent module's output. Its shape is documented as output (seq_len, batch, hidden_size * num_directions):

  • Why does the output shape have hidden_size as a factor?
  • What is num_directions?
  • What if I want the output sequence to have a specific size (number of features)?
1 Like

Hello @osm3000,

The output per time step has dimension hidden_size per direction (it is the hidden state of the last layer). In an LSTM, the output is the "modulated" cell state.

num_directions is 1 for a "usual" LSTM and 2 for a bidirectional one.

Commonly, you would then take the last (time-direction) output per batch item, which has hidden_size features, i.e.

x = output[-1] # shape (batch, hidden_size) for a unidirectional LSTM

You can then use x as input into whatever layer you want to have above the LSTM.
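
As a minimal sketch of that pattern (the class name and sizes here are made up for illustration), an LSTM followed by a linear layer that maps the last time step's output to class scores:

import torch
import torch.nn as nn

class SequenceClassifier(nn.Module):   # hypothetical example module
    def __init__(self, in_size, hidden_size, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(in_size, hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):              # x: (seq_len, batch, in_size)
        output, _ = self.lstm(x)
        last = output[-1]              # (batch, hidden_size), last time step
        return self.fc(last)           # (batch, num_classes)

logits = SequenceClassifier(5, 64, 7)(torch.randn(10, 3, 5))   # shape (3, 7)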

Best regards

Thomas

8 Likes
  • hidden_size represents the output size of the last recurrent layer. I guess it’s called hidden_size as the output of the last recurrent layer is usually further transformed (as in the Elman model referenced in the docs). Also, if there are several layers in the RNN module, all the hidden ones will have the same number of features: hidden_size.

  • num_directions is either 1 or 2 depending on the boolean argument bidirectional. If the sequence is processed in both directions you’ll get two values for each time step.

  • If you need a fixed number of output features, you either set hidden_size to that value, or you add an output layer that maps from hidden_size to your output space (see the shape check just after this list).
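
A quick shape check illustrating the last two points, with made-up sizes:

import torch
import torch.nn as nn

seq_len, batch, in_size, hidden_size = 10, 3, 5, 20
x = torch.randn(seq_len, batch, in_size)

bi_lstm = nn.LSTM(in_size, hidden_size, bidirectional=True)   # num_directions = 2
out, _ = bi_lstm(x)
print(out.shape)              # torch.Size([10, 3, 40]) == (seq_len, batch, hidden_size * num_directions)

proj = nn.Linear(2 * hidden_size, 8)   # map the last time step to a fixed 8 output features
print(proj(out[-1]).shape)    # torch.Size([3, 8])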

4 Likes

I am confused about when to give the entire sequence as input to an RNN (specifically an LSTM) and when to give the input stepwise. The way you suggested in this example seems more intuitive, but the example in the NLP From Scratch: Translation with a Sequence to Sequence Network and Attention tutorial takes only one word at a time.

At first it seems that the machine translation example does this to get an output at each time step (i.e. the encoded version after each word is given), but wouldn't the output sequence record all the time steps anyway?

I tried a small example, feeding the entire input sequence at once (like in the example you mentioned) and feeding one word at a time (like in the MT example). The outputs at each time step are different in the two cases. Shouldn't they ideally be the same?
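
For reference, here is a minimal sketch of the comparison I mean: feeding the whole sequence at once versus one step at a time while passing the returned hidden state back in at each step:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=5, hidden_size=8)
seq = torch.randn(4, 1, 5)                 # (seq_len, batch, input_size)

full_out, _ = lstm(seq)                    # whole sequence at once

step_outs, hidden = [], None
for t in range(seq.size(0)):               # one time step at a time, carrying the state
    out_t, hidden = lstm(seq[t:t+1], hidden)
    step_outs.append(out_t)

print(torch.allclose(full_out, torch.cat(step_outs)))   # True when the state is carried over and there is no dropout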

1 Like

@FuriouslyCurious I also have the same problem and the same case as you. Did you solve it? I am using this Code to solve the problem. I just changed the sequence_length and input_size. However, I don't know whether it works for your problem or not.
@Tudor_Berariu What is num_layers in LSTM? How do I make it a stacked LSTM, or does that argument mean the LSTM layers are stacked?

2 Likes

Because I need practice with LSTMs, here’s my go at solving the problem – 5 samples of 5 rows x 5 columns; one label per sample … Any critique of silly mistakes or bad practices appreciated greatly:

import torch
import torch.nn as nn

EPOCHS = 500
IN_SIZE = 5
NUM_SAMPLES = 5

def generate_data(rows, columns, samples):
    X = []
    y = []
    transformations = {
        '11': lambda x, y: x + y,
        '15': lambda x, y: x - y,
        '10': lambda x, y: x * y,
        '30': lambda x, y: x / y,
        '2': lambda x, y: x + y,
    }
    for j in range(samples):
        data_set = []
        for i in range(columns):
            data = []
            for val, fn in transformations.items():
                data.append(int(fn(int(val), i+j+1)))
            data_set.append(data)
        X.append(data_set)
        y.append([j+1])
    return X, y


class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.LSTM(
            input_size=5,
            hidden_size=NUM_SAMPLES+1,
            num_layers=2,
            batch_first=True,
        )

    def forward(self, x):
        out, (h_n, c_n) = self.rnn(x, None)  # h_n/c_n: final hidden and cell states
        return out[:, -1, :]  # Return output at last time-step


X, y = generate_data(IN_SIZE, 5, NUM_SAMPLES)
X = torch.FloatTensor(X)
y = torch.LongTensor(y)

rnn = RNN()
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()


for j in range(EPOCHS):
    for i, item in enumerate(X):
        item = item.unsqueeze(0)
        output = rnn(item)
        loss = loss_func(output, y[i])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if j % 5 == 0:
        print('Loss: ', loss.item())

print('Testing:\n========')
for i, item in enumerate(X):
    print(y[i])
    outp = rnn(item.unsqueeze(0))
    print(outp.argmax(dim=1).item())

The PyTorch RNN implementation confused me very badly, but I guess I finally figured it out.
First of all, let's see how a simple RNN does its job. The formula for a vanilla RNN is as follows:

$$ h_t = \tanh(W_1 \cdot X_t + W_2 \cdot h_{t-1}) $$

This, as we all know, is the formula for calculating the new hidden state. For the output part, however, a second step is needed, which is as follows:

$$ o_t = \mathrm{softmax}(W_3 \cdot h_t) $$
W_3 has shape (output_size, hidden_size), which, after being multiplied by the hidden state h_t (with shape (hidden_size, 1)), results in an output vector of shape (output_size, 1).
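
As a concrete illustration of the two formulas above (a hand-rolled sketch with made-up weight names; PyTorch's actual parametrization also adds bias terms):

import torch

input_size, hidden_size, output_size = 26, 100, 26
W1 = torch.randn(hidden_size, input_size)    # input-to-hidden weights
W2 = torch.randn(hidden_size, hidden_size)   # hidden-to-hidden weights
W3 = torch.randn(output_size, hidden_size)   # hidden-to-output weights

x_t = torch.randn(input_size, 1)
h_prev = torch.zeros(hidden_size, 1)

h_t = torch.tanh(W1 @ x_t + W2 @ h_prev)     # new hidden state, shape (hidden_size, 1)
o_t = torch.softmax(W3 @ h_t, dim=0)         # output distribution, shape (output_size, 1)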

PyTorch doesn't calculate the output by default, so it is up to the user to write it down! RNNs in PyTorch return two results: one usually called output, and the other the hidden state. What is really returned as output is all the hidden states for all the time steps (the first result), with the final hidden state as the second result. This is also the case for LSTM!
This is also indicated in the documentation, which reads:

# the first value returned by LSTM is all of the hidden states throughout
# the sequence. the second is just the most recent hidden state
# (compare the last slice of "out" with "hidden" below, they are the same)
# The reason for this is that:
# "out" will give you access to all hidden states in the sequence
# "hidden" will allow you to continue the sequence and backpropagate,
# by passing it as an argument  to the lstm at a later time

So how do we do it in PyTorch then?
Basically, in order to get the output size of your choice, you should add an nn.Linear module with the desired output size as its out_features and hidden_size as its in_features.
So let's say we want a many-to-many RNN configuration. Our input dim is 26 and our output is 26, as there are 26 letters in the alphabet (for the sake of our example, both one-hot encoded), and our hidden_size is, let's say, 100. We would then write:

import torch
import torch.nn as nn

class ou_rnn(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_size, num_layer):
        super().__init__()

        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size=input_dim, hidden_size=hidden_size, num_layers=num_layer)
        # you specify your output size in the linear module
        self.fc = nn.Linear(hidden_size, output_dim)

    def forward(self, x, h):
        out, h = self.rnn(x, h)
        # flatten (seq_len, batch, hidden_size) so every time step goes through the linear layer
        out = out.contiguous().view(-1, self.hidden_size)
        out = self.fc(out)
        return out, h

and in the training loop, the only thing that needs to be taken care of is the labels: they should be reshaped as well, to label.view(batch_size * sequence_length), when used with CrossEntropyLoss.
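
For instance, a rough sketch of such a training step, assuming inputs of shape (sequence_length, batch_size, 26) and integer targets of shape (sequence_length, batch_size):

import torch
import torch.nn as nn

seq_len, batch_size = 12, 4
model = ou_rnn(input_dim=26, output_dim=26, hidden_size=100, num_layer=2)
criterion = nn.CrossEntropyLoss()

x = torch.randn(seq_len, batch_size, 26)                  # stand-in for one-hot inputs
labels = torch.randint(0, 26, (seq_len, batch_size))      # one class index per time step

out, h = model(x, None)                                   # out: (seq_len * batch_size, 26)
loss = criterion(out, labels.view(-1))                    # flatten the labels to match
loss.backward()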
That's it!
As you can see, you can easily have any kind of RNN (or LSTM) configuration: many-to-many, many-to-one, or whatever!

IMHO, the source of all of these issues is the misleading naming used in PyTorch. Instead of calling all the hidden states "output", simply refer to them as all_hidden_states!

Hope this is useful.
In case there is something that I missed, please correct me.
I also found this to be very helpful

5 Likes