How to implement convolutional GRUs/LSTMs/RNNs

I want to have the standard LSTM/GRU/RNN setup but swap the linear function with a convolution. Is it possible to do that in PyTorch in a clean and efficient manner?

Ideally it would still work with packing, varying sequence lengths, etc.

Small sample code of a trivial way to pass data through it would be super useful like:

# Based on Robert Guthrie tutorial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from pdb import set_trace as st


def step_by_step(net, sequence, hidden):
    """Run an LSTM over a sequence one token (one time step) at a time.

    Args:
        net: an ``nn.LSTM`` (or API-compatible module) expecting input of
            shape ``(seq_len, batch, input_size)``.
        sequence: iterable of tensors, each of shape ``(1, input_size)``.
        hidden: tuple ``(h_0, c_0)`` of initial hidden and cell states,
            each of shape ``(num_layers, batch, hidden_size)``.

    Returns:
        ``(out, hidden)`` from the final time step, so callers can inspect
        the last output or continue the sequence later. ``out`` is ``None``
        if the sequence is empty.
    """
    print('start processing sequence')
    out = None
    ## process the sequence one element at a time
    for i, token in enumerate(sequence):
        print(f'-- i = {i}')
        h_n, c_n = hidden  # hidden state, cell state
        ## add fake seq_len and batch dimensions: (1, 1, input_size)
        processed_token = token.view(1, 1, -1)  # e.g. torch.Size([1, 1, 3])
        print(f'processed_token.size() = {processed_token.size()}')
        print(f'h_n.size() = {h_n.size()}')
        # after each step, hidden contains the (h_n, c_n) for the next step
        out, hidden = net(processed_token, hidden)
    return out, hidden

def whole_seq_all_at_once(lstm, sequence, hidden):
    """Run an LSTM over the entire sequence in a single call.

    The first value returned by the LSTM ("out") holds the hidden state for
    every time step throughout the sequence.  The second ("hidden") is just
    the most recent hidden/cell state pair -- compare the last slice of
    "out" with ``h_n``: they are the same.  "out" gives you access to all
    hidden states in the sequence, while "hidden" lets you continue the
    sequence and backpropagate later by passing it back into the LSTM.

    Args:
        lstm: an ``nn.LSTM`` expecting input of shape
            ``(seq_len, batch, input_size)``.
        sequence: list of ``Tx`` tensors, each of shape ``(1, input_size)``.
        hidden: tuple ``(h_0, c_0)``, each of shape
            ``(num_layers, batch, hidden_size)``.

    Returns:
        ``(out, (h_n, c_n))`` as returned by the LSTM.
    """
    h, c = hidden
    Tx = len(sequence)
    ## concatenate the list of tensors along dim 0, i.e. stack them
    ## downwards creating new rows -> (Tx, input_size)
    sequence = torch.cat(sequence)  # e.g. (5, 3)
    ## add a singleton batch dimension of size 1 -> (Tx, 1, input_size)
    sequence = sequence.view(Tx, 1, -1)  # e.g. (5, 1, 3)
    print(f'sequence.size() = {sequence.size()}')
    print(f'h.size() = {h.size()}')
    print(f'c.size() = {c.size()}')
    out, hidden = lstm(sequence, hidden)
    ## "out" gives access to all hidden states in the sequence
    print(f'out = {out}')
    print(f'out.size() = {out.size()}')  # e.g. (5, 1, hidden_size)
    h_n, c_n = hidden
    print(f'h_n = {h_n}')
    print(f'h_n.size() = {h_n.size()}')
    print(f'c_n = {c_n}')
    print(f'c_n.size() = {c_n.size()}')
    return out, hidden

if __name__ == '__main__':
    ## model hyper-parameters
    hidden_size = 6
    input_size = 3
    lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size)
    ## make a sequence of length Tx (list of Tx tensors of shape (1, input_size))
    Tx = 5
    sequence = [torch.randn(1, input_size) for _ in range(Tx)]
    ## initialize the hidden state: (h_0, c_0), each (num_layers, batch, hidden_size)
    hidden = (torch.randn(1, 1, hidden_size), torch.randn(1, 1, hidden_size))
    ## exercise both ways of driving the LSTM over the same data
    step_by_step(lstm, sequence, hidden)
    whole_seq_all_at_once(lstm, sequence, hidden)
    print('DONE \a')

Cross posted: