I have an encoder class and a decoder class (which I may later try to extend with attention). I want to add one or two fully connected layers after the decoder, and I am confused about how to proceed.
def get_embedding(matrix):
    """Build an ``nn.Embedding`` initialised from a pre-trained weight matrix.

    Args:
        matrix: 2-D tensor of shape ``(num_embeddings, embedding_dim)``.

    Returns:
        ``(embedding, embedding_dim, num_embeddings)`` -- the layer, the width
        of each embedding vector, and the vocabulary size.
    """
    # Unpacking the size into two names requires a 2-D matrix.
    output_size, emb_size = matrix.size()
    emb = nn.Embedding(output_size, emb_size)
    # load_state_dict copies the values into the layer's own parameter,
    # which stays trainable (requires_grad is left untouched).
    emb.load_state_dict({'weight': matrix})
    return emb, emb_size, output_size
class EncoderRNN(nn.Module):
    """Embed batches of token-id sequences and run them through an LSTM.

    Args:
        embs: pre-trained embedding matrix ``(vocab_size, emb_size)``, passed
            to ``get_embedding``.
        hidden_size: number of features in the LSTM hidden state.
        n_layers: number of stacked LSTM layers.
    """

    def __init__(self, embs, hidden_size, n_layers=2):
        super(EncoderRNN, self).__init__()
        self.emb, emb_size, _ = get_embedding(embs)
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(emb_size, hidden_size, batch_first=True,
                            num_layers=n_layers, bidirectional=False)

    def forward(self, input, hidden=None):
        """Encode ``input`` token ids of shape ``(batch, seq_len)``.

        Args:
            input: token ids; ``batch_first=True`` puts the batch first.
            hidden: ``(h_0, c_0)`` tuple such as ``initHidden`` returns. When
                it is ``None``, ``nn.LSTM`` starts from zero states.

        Returns:
            ``(outputs, (h_n, c_n))`` exactly as returned by ``nn.LSTM``.
        """
        return self.lstm(self.emb(input), hidden)

    def initHidden(self, batch_size):
        """Return a zero ``(h_0, c_0)`` state pair for ``batch_size`` sequences.

        ``nn.LSTM`` needs both the hidden state and the cell state. A single
        tensor, which the previous version returned, is rejected at runtime.
        The zeros are created with the embedding weight's device and dtype, so
        the model also works after ``.to(device)``.
        """
        weight = self.emb.weight
        shape = (self.n_layers, batch_size, self.hidden_size)
        return weight.new_zeros(shape), weight.new_zeros(shape)
class DecoderRNN(nn.Module):
    """Single-step LSTM decoder producing log-probabilities over its vocabulary.

    Args:
        embs: pre-trained embedding matrix ``(vocab_size, emb_size)``. Its
            ``vocab_size`` also sets the width of the output projection.
        hidden_size: LSTM hidden size. It must equal the encoder's if the
            encoder's final state is passed in as ``hidden``.
        n_layers: number of stacked LSTM layers. The same caveat applies.
    """

    def __init__(self, embs, hidden_size, n_layers=2):
        super(DecoderRNN, self).__init__()
        self.emb, emb_size, output_size = get_embedding(embs)
        self.lstm = nn.LSTM(emb_size, hidden_size, batch_first=True,
                            num_layers=n_layers)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, inp, hidden):
        """Run one decoding step.

        Args:
            inp: token ids for the current step, shape ``(batch,)``.
            hidden: ``(h, c)`` LSTM state, e.g. the encoder's final state.

        Returns:
            ``(log_probs, hidden)``. ``log_probs`` has shape
            ``(batch, vocab_size)`` and ``hidden`` is the updated ``(h, c)``.
        """
        # (batch, emb_size) -> (batch, 1, emb_size): a length-1 sequence.
        emb = self.emb(inp).unsqueeze(1)
        res, hidden = self.lstm(emb, hidden)
        # Take the single time step and normalise over the vocabulary axis.
        # An explicit dim avoids the deprecated implicit-dim warning.
        res = F.log_softmax(self.out(res[:, 0]), dim=-1)
        return res, hidden
This is the class I am trying to write to combine the encoder and decoder with a fully connected layer, but I have not been able to get it working.
class Combine(nn.Module):
    """Sentence-pair classifier: encoder -> decoder -> fully connected head.

    The encoder reads the first sentence. Its final ``(h, c)`` state seeds the
    decoder LSTM, which reads the second sentence. The top layer's final
    hidden state is then passed through two fully connected layers to produce
    class scores.

    Args:
        emb1: embedding matrix for the encoder vocabulary.
        emb2: embedding matrix for the decoder vocabulary.
        hidden_size: LSTM hidden size shared by encoder and decoder.
        n_layers: number of LSTM layers in both encoder and decoder.
        output_dim: number of target classes (keyword-only, e.g. 2 or 3).
    """

    def __init__(self, emb1, emb2, hidden_size, n_layers=2, *, output_dim):
        super(Combine, self).__init__()
        # Identical hidden_size and n_layers on both sides so the encoder's
        # final state can be handed to the decoder unchanged.
        self.enc = EncoderRNN(emb1, hidden_size, n_layers)
        self.dec = DecoderRNN(emb2, hidden_size, n_layers)
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_dim)

    def forward(self, inp, dec_inp=None):
        """Classify a batch of (encoder sentence, decoder sentence) pairs.

        Args:
            inp: encoder token ids, shape ``(batch, enc_len)``.
            dec_inp: decoder token ids, shape ``(batch, dec_len)``. If
                ``None``, the encoder's final hidden state is classified
                directly.

        Returns:
            Log-probabilities of shape ``(batch, output_dim)``. Train with
            ``nn.NLLLoss``, or drop the final ``log_softmax`` and use
            ``nn.CrossEntropyLoss``.
        """
        # Passing hidden=None lets nn.LSTM start from zero states.
        _, (h_n, c_n) = self.enc(inp, None)
        if dec_inp is not None:
            # DecoderRNN.forward consumes one token per call and projects to
            # the vocabulary, which classification does not need. Run the
            # decoder's embedding + LSTM over the whole sentence in one call.
            # NOTE(review): with right-padded batches, h_n is the state after
            # the padding; use pack_padded_sequence if lengths vary.
            _, (h_n, _) = self.dec.lstm(self.dec.emb(dec_inp), (h_n, c_n))
        # h_n is (n_layers, batch, hidden_size); keep the top layer.
        features = h_n[-1]
        out = F.relu(self.fc1(features))
        out = self.fc2(out)
        return F.log_softmax(out, dim=-1)
I think the last few lines are the part I am unclear about. What should I feed into the decoder, and how should its output connect to a fully connected layer for a classification problem?
The inputs to the model are sentences for the encoder and another set of sentences for the decoder, while the labels (target classes) take one of 2–3 classes.