I’m struggling with the parameters optimization of parameters of a seq2seq model on a simple toy dataset with sinusoidal input (integers transformed in one hot encoder) and cosine output and results (see figure). The signal is one large input sliced in chunks which period increase sigltly with time.
Network parameters don’t seem to optimize correctly since the output of the decoder is always a near zeros…
It’s not the first time I’ve problems with pytorch gradient descent, I’m wondering if something in my code is causing gradients bugs…
Any help is really appreciated !!
import torch
from torch import nn, optim
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm import tqdm
################################################################################
### Model layers (Encoder, Encoder Hidden to Decoder Hidden, Decoder, Seq2Seq)
################################################################################
class EncoderRNN(nn.Module):
'''
Sequence encoder using RNN
'''
def __init__(self, input_size, hidden_size,n_layers=1):
super(EncoderRNN, self).__init__()
self.n_layers = n_layers
self.hidden_size = hidden_size
self.embedding = nn.Embedding(input_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size,
bidirectional=True,batch_first=True,
num_layers=n_layers)
def forward(self, input, hidden):
embedded = self.embedding(input.long())
output, hidden = self.gru(embedded, torch.transpose(hidden,1,0))
hidden = torch.transpose(hidden,1,0)
return output, hidden
def initHidden(self,device,batch_size):
return torch.zeros(batch_size,2*self.n_layers,self.hidden_size,dtype=torch.float,device=device)
class EncoderHidden2DecoderHidden(nn.Module):
'''
Concatenate hidden layers from encoder and decoder at the last step
in order to initialise decoder hidden layer
'''
def __init__(self, hidden_size):
super(EncoderHidden2DecoderHidden, self).__init__()
self.hidden_size = hidden_size
self.Linear = nn.Linear(hidden_size*3,hidden_size)
def forward(self,encoder_hidden,decoder_hidden):
encoder_hidden = self.cat_directions(encoder_hidden)
hidden = torch.cat((encoder_hidden,decoder_hidden),2)
hidden = self.Linear(hidden)
return hidden
def cat_directions(self,h):
"""
If the encoder is bidirectional, do the following transformation.
(#directions * #layers, #batch, hidden_size) -> (#layers, #batch, #directions * hidden_size)
"""
return torch.cat([h[:,0:h.size(1):2], h[:,1:h.size(1):2]], 2)
class AttnDecoderRNN(nn.Module):
'''
Sequence decoder using RNN and attention over encoder outputs
'''
def __init__(self, input_size,hidden_size, input_seq_len, output_size,n_layers=1,dropout_p=0.1):
super(AttnDecoderRNN, self).__init__()
self.hidden_size = hidden_size
self.output_size = output_size
self.dropout_p = dropout_p
self.input_seq_len = input_seq_len
self.n_layers = n_layers
self.embedding_prevout = nn.Linear(self.output_size, self.output_size)
self.attn = nn.Linear(self.hidden_size * (n_layers)+1, self.input_seq_len)
self.attn_combine = nn.Linear(self.hidden_size *2 + self.output_size, self.hidden_size)
self.dropout = nn.Dropout(self.dropout_p)
self.gru = nn.GRU(self.hidden_size,
self.hidden_size,
batch_first=True,
num_layers=n_layers)
self.out = nn.Linear(self.hidden_size, self.output_size)
self.Relu = nn.ReLU()
def forward(self,prev_out, hidden0, encoder_outputs):
prev_out = self.embedding_prevout(prev_out)
if hidden0.size(0)>1:
hidden = torch.cat([h.contiguous().view(1,1,-1) for h in hidden0],0)
else:
hidden = hidden0.view(1,1,-1)
attn_weights = F.softmax(self.attn(torch.cat((prev_out, hidden), 2)), dim=2)
attn_applied = torch.bmm(attn_weights,torch.transpose(encoder_outputs,1,2))
output = torch.cat((prev_out, attn_applied), 2)
output = self.attn_combine(output)
output = self.Relu(output)
hidden0 = torch.transpose(hidden0,0,1)
output, hidden = self.gru(output, hidden0)
hidden = torch.transpose(hidden,0,1)
output = self.out(self.Relu(output))
return output, hidden, attn_weights
def initHidden(self,device,batch_size):
return torch.zeros(batch_size,self.n_layers, self.hidden_size,dtype=torch.float,device=device)
class DecoderRNN(nn.Module):
'''
Decoder WITHOUT attention (for comparison)
'''
def __init__(self, hidden_size, output_size,n_layers):
super(DecoderRNN, self).__init__()
self.hidden_size = hidden_size
self.n_layers = n_layers
self.embedding = nn.Linear(output_size, hidden_size)
self.gru = nn.GRU(self.hidden_size,
self.hidden_size,
batch_first=True,
num_layers=n_layers)
self.out = nn.Linear(hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden,encoder_outputs=[]):
output = self.embedding(input).view(1, 1, -1)
output = F.relu(output)
output, hidden = self.gru(output, hidden)
output = self.out(output[0])
return output, hidden,0
class Seq2Seq(nn.Module):
'''
Seq2Seq model with attention
'''
def __init__(self, input_size=5,
hidden_size=32,
output_size=1,
dropout_p=0.2,
batch_size=1,
n_layers=1,
seq_len_input=100,
seq_len_output=100,
use_attention=True,
phase='',
device=''):
super(Seq2Seq, self).__init__()
self.hidden_size = hidden_size
self.output_size = output_size
self.dropout_p = dropout_p
self.n_layers = n_layers
self.device = device
self.batch_size = batch_size
self.n_layers = n_layers
self.seq_len_input = seq_len_input
self.seq_len_output = seq_len_output
self.phase = phase
self.encoder = EncoderRNN(input_size, self.hidden_size,n_layers=self.n_layers)
#Attention model vs standard RNN
if use_attention==True:
self.decoder = AttnDecoderRNN(input_size,self.hidden_size,
self.seq_len_input,
self.output_size,
dropout_p=dropout_p,
n_layers=self.n_layers)
else:
self.decoder = DecoderRNN(self.hidden_size, output_size,n_layers)
self.EH2DH = EncoderHidden2DecoderHidden(self.hidden_size)
#Save encoder and decoder hidden vector to propagate through all the sequence
self.encoder_hidden_init = torch.zeros(self.batch_size,2*self.n_layers,self.hidden_size,dtype=torch.float,device=self.device)
self.decoder_hidden_init = torch.zeros(self.batch_size,self.n_layers, self.hidden_size,dtype=torch.float,device=self.device)
#Save decoder previous output
self.decoder_input_init = torch.zeros(self.batch_size,1, 1,dtype=torch.float,device=self.device)
def forward(self,input_tensor,target_tensor,use_teacher_forcing):
'''
Model forward pass: Encoder -> Decoder with attention
'''
encoder_outputs = torch.zeros(self.batch_size,self.hidden_size*2,self.seq_len_input, dtype=torch.float, device=self.device)
decoder_outputs = torch.zeros(self.batch_size,1,self.seq_len_output, dtype=torch.float, device=self.device)
#Record attention
if self.phase=='pred':
decoder_attentions = torch.zeros(self.seq_len_output, self.seq_len_input, dtype=torch.float, device=self.device)
# Encoder
encoder_hidden = self.encoder_hidden_init #Encoder hidden layer initialized
for ei in range(self.seq_len_input):
encoder_output, self.encoder_hidden = self.encoder(input_tensor[:,:,ei], encoder_hidden)
encoder_outputs[:,:,ei] = encoder_output[0, 0]
if ei==self.seq_len_output-1:
self.encoder_hidden_init = encoder_hidden.detach()
# Initialize decoder hidden using encoder hidden and decoder hidden at previous step
decoder_hidden = self.EH2DH(encoder_hidden,self.decoder_hidden_init)
#Initialize decoder input
decoder_input = self.decoder_input_init
# Teacher forcing: Feed the target as the next input
if use_teacher_forcing and self.phase=='train':
print('Teacher forcing')
for di in range(self.seq_len_output):
decoder_output, decoder_hidden, decoder_attention = self.decoder(
decoder_input, decoder_hidden, encoder_outputs)
decoder_input = target_tensor[:,:,di].unsqueeze(2) # Teacher forcing
decoder_outputs[:,0,di] = decoder_output.squeeze()
# Without teacher forcing: use its own predictions as the next input
else:
# print('No Teacher forcing')
for di in range(self.seq_len_output):
decoder_output, decoder_hidden, decoder_attention = self.decoder(
decoder_input, decoder_hidden, encoder_outputs)
decoder_input = decoder_output.detach() # detach from history as input for next step
decoder_outputs[:,0,di] = decoder_output.squeeze()
#Save decoder input for next sequence
self.decoder_input_init = decoder_output.detach()
#Capture the last hidden state to initalize next sequce
self.decoder_hidden_init = decoder_hidden.detach()
if self.phase=='train' or self.phase=='val':
return decoder_outputs
if self.phase=='pred':
decoder_attentions = torch.cat([decoder_attention.detach().data for decoder_attention in decoder_attentions],0)
return decoder_outputs.detach().data,decoder_attentions
def initHidden(self):
self.encoder_hidden_init = torch.zeros(self.batch_size,2*self.n_layers,self.hidden_size,dtype=torch.float,device=self.device)
self.decoder_hidden_init = torch.zeros(self.batch_size,self.n_layers, self.hidden_size,dtype=torch.float,device=self.device)
##############################################################
### Dataset
##############################################################
class toy_dataset(Dataset):
def __init__(self,device,batch_size,seqlen):
super(toy_dataset, self).__init__()
seqlen=seqlen
self.num_samples=50
x_np = np.zeros((batch_size,seqlen,self.num_samples))
y_np = np.zeros((batch_size,seqlen,self.num_samples))
for b in range(batch_size):
r = np.floor(np.random.rand(1)*5+3)
for k in range(self.num_samples):
steps = np.linspace(r+k*np.pi,r+(k+1)*np.pi, seqlen, dtype=np.float32)
x_np[b,:,k] = np.floor(np.sin(steps*(k+1)*0.5)*2+2)
y_np[b,:,k] = np.cos(steps*r*(k+1)*0.5)
self.x = torch.from_numpy(x_np).long().to(device)
self.y = torch.from_numpy(y_np).float().to(device)
self.x.require_grad=True
self.index=0
def __iter__(self):
return self
def __next__(self):
while True:
if self.index==self.num_samples:
self.index=0
raise StopIteration()
x = self.x[:,:,self.index].unsqueeze(1)
y = self.y[:,:,self.index].unsqueeze(1)
self.index+=1
return x,y
def __len__(self):
return self.num_samples
def trch2npy(trch,device):
'''
Convert torch to numpy array. Auxilliary function
'''
if device.type=="cpu":
return trch.data.numpy()
else:
return trch.cpu().data.numpy()
##############################################################
### Model training class
##############################################################
class Seq2Seq_model:
'''
Seq2seq architecture with attention. This class allows to train the network and save
parameters along the training
inputs:
input_size=5 : Input vocabulary size
output_size=1: size of the output: 1-d vector
Nepochs: Number of epochs
batch_size: batch size
seq_len_output: hist: Lenght of the predicted histone sequence
seq_len_input: Lenght of the DNA sequence used to predict centered aroud input sequence middle.
hidden_size: size of the RNN hidden vectors
dropout_p: Dropout of the attention over the previous decoder output
n_layers: Number of layers per RNN
'''
def __init__(self,
input_size,
output_size,
batch_size=1,
seq_len_input=100,
seq_len_output=100,
hidden_size=32,
dropout_p=0.2,
use_attention=True,
n_layers=1):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.batch_size = batch_size
self.n_layers = n_layers
self.hidden_size = hidden_size
self.Seq2Seq = Seq2Seq( input_size=input_size,
hidden_size=hidden_size,
output_size=output_size,
dropout_p=dropout_p,
batch_size=batch_size,
n_layers=n_layers,
device=self.device,
seq_len_input=seq_len_input,
seq_len_output=seq_len_output,
use_attention=use_attention,
phase='')
self.criterion = nn.MSELoss()
if self.device.type=='cuda':
self.Seq2Seq.cuda()
self.criterion.cuda()
self.teacher_forcing_ratio=0.5
self.Seq2Seq_optimizer = optim.Adam(self.Seq2Seq.parameters(),lr=0.0001,betas=(0.9,0.999))
self.train_loader = toy_dataset(self.device,self.batch_size,seq_len_input)
self.val_loader = toy_dataset(self.device,self.batch_size,seq_len_input)
self.train_len = len(self.train_loader)
self.val_len = len(self.val_loader)
def train(self):
'''
Train the model on training set
'''
for self.epoch in tqdm(range(10),desc='train'):
self.Seq2Seq.train()
self.Seq2Seq.phase='train'
self.Seq2Seq.initHidden()
for self.counter,(input_tensor,target_tensor) in enumerate(self.train_loader):
input_tensor.require_grad=True
self.forward(input_tensor,target_tensor)
self.backward()
self.validate()
def forward(self,input_tensor,target_tensor):
'''
Forward input through the network
'''
#Gradient initialization
self.Seq2Seq_optimizer.zero_grad()
if self.Seq2Seq.phase=='train' or self.Seq2Seq.phase=='val':
# Teacher forcing: Use output given by the network at previous step or the ground truth value
use_teacher_forcing = True if torch.rand(1) < self.teacher_forcing_ratio else False
decoder_outputs = self.Seq2Seq(input_tensor,target_tensor,use_teacher_forcing)
self.loss = self.criterion(decoder_outputs.float(), target_tensor.float())
else:
return self.Seq2Seq(input_tensor,[],False)
def backward(self):
'''
Model backward
'''
self.loss.backward()
self.Seq2Seq_optimizer.step()
def validate(self):
'''
Compute loss score on validation set
'''
self.Seq2Seq.eval()
self.Seq2Seq.phase='val'
self.Seq2Seq.initHidden()
# print("Validation")
for self.counter,(input_tensor,target_tensor) in enumerate(self.val_loader):
with torch.no_grad():
self.forward(input_tensor,target_tensor)
def predict(self):
'''
Compute decoder output and attention score
'''
self.Seq2Seq.eval()
self.Seq2Seq.phase='pred'
decoder_outputs = []
decoder_attentions = []
for self.counter,(input_tensor,target_tensor) in enumerate(tqdm(self.val_loader,total=self.val_len,desc='prediction on validation set')):
if self.counter==self.val_len:
break
with torch.no_grad():
decoder_output,decoder_attention = self.forward(input_tensor,[])
decoder_outputs.append(trch2npy(decoder_output,self.device))
decoder_attentions.append(trch2npy(decoder_attention,self.device))
decoder_outputs = np.concatenate(decoder_outputs,axis=0)
return decoder_outputs,decoder_attentions
##############################################################
### Model Training
##############################################################
import numpy as np
import scipy
import matplotlib.pyplot as plt
s2s = Seq2Seq_model(input_size=5,
output_size=1,
batch_size=1,
seq_len_input=50,
seq_len_output=50,
hidden_size=32,
dropout_p=0.2,
n_layers=1,
use_attention=True)
s2s.train()
decoder_outputs, attention_list = s2s.predict()
# Get validation Inputs and outputs
xs=[]
ys=[]
for x,y in s2s.val_loader:
xs.append(x)
ys.append(y)
xs = torch.cat(xs,0).cpu().numpy()
ys = torch.cat(ys,0).cpu().numpy()
#Display
f,axes = plt.subplots(5,2,figsize=(10,20))
for m,ax in enumerate(axes.flatten()):
ax.plot(decoder_outputs[m,0,:],'r',label='model output')
ax.plot(ys[m,0,:],'b',label='ground truth')
ax.plot(xs[m,0,:],'g',label='input')
ax.legend()
ax.set_xlabel('time')
plt.show()