Padding for pyramidal RNNs

I am setting up a pyramidal RNN (as in the Listen, Attend and Spell paper). It seems to me that the packed-sequence API is not really feasible here, because we shrink the number of timesteps with each layer. Am I right in assuming that the padding will have to be done manually, and that we then mask/ignore the zeros?
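To illustrate the masking idea, here is a minimal sketch (the shapes and lengths are made up for the example): build a boolean mask from the true sequence lengths and zero out the padded timesteps so they contribute nothing downstream.

```python
import torch

# Hypothetical mini-batch: 3 sequences of lengths 6, 4, 2, padded to 6 steps.
lengths = torch.tensor([6, 4, 2])
batch, max_len, feat = 3, 6, 5
x = torch.randn(batch, max_len, feat)

# Boolean mask: True where a timestep is real, False where it is padding.
mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)  # (batch, max_len)

# Zero out the padded timesteps.
x = x * mask.unsqueeze(-1)
```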

Here is my attempt at setting up the pyramidal stack:

class Pyramidal_GRU(nn.Module):
    def __init__(self, input_size, hidden_size, seq_len=8, stack_size=3):
        super(Pyramidal_GRU, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        self.stack_size = stack_size
        self.gru0 = nn.GRU(self.input_size, self.hidden_size, batch_first=True)  # initial projection

        self.pyramid = nn.ModuleList(
            [nn.GRU(2 * hidden_size, hidden_size, batch_first=True) for _ in range(stack_size)])

    def forward(self, input):
        x, hidden = self.gru0(input)
        seq_len = self.seq_len
        for i in range(self.stack_size):
            # concatenate adjacent timestep pairs; need 'contiguous' or view errors out
            x = x.contiguous().view(-1, seq_len // 2, 2 * self.hidden_size)
            seq_len //= 2
            x, _ = self.pyramid[i](x)
        return x
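As a standalone sanity check of the reshape in the loop (dummy tensor, no RNN involved): each reduced timestep should be the concatenation of two consecutive original timesteps along the feature dimension.

```python
import torch

# One pyramidal reduction step: halve T, double the feature dimension.
B, T, H = 2, 8, 4
x = torch.arange(B * T * H, dtype=torch.float32).view(B, T, H)

y = x.contiguous().view(B, T // 2, 2 * H)

# Reduced step j of each sequence = concat of original steps 2j and 2j+1.
assert torch.equal(y[0, 0], torch.cat([x[0, 0], x[0, 1]]))
```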

Never mind this, I think I have answered my own question. We shouldn't be using the packed-sequence API here; instead, deal with the padding manually. For example, pad to the max length of the mini-batch and then run the bidirectional LSTM on the padded sequence. This might lead to some bogus computation on the padding, which probably cannot be avoided.
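For what it's worth, a sketch of that manual approach (the lengths and the rounding scheme are my own assumptions): pad to the batch max, rounded up to a multiple of 2**stack_size so every reduction divides evenly, and recompute the valid lengths after the stack to know which output steps are bogus.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical batch of variable-length feature sequences.
stack_size = 3
factor = 2 ** stack_size          # each layer halves T, so T must be divisible by 8
seqs = [torch.randn(t, 5) for t in (13, 9, 6)]

# Pad to the batch max, rounded up to a multiple of 2**stack_size.
max_len = max(s.size(0) for s in seqs)
max_len = ((max_len + factor - 1) // factor) * factor   # 13 -> 16

padded = pad_sequence(seqs, batch_first=True)           # (3, 13, 5)
pad = padded.new_zeros(len(seqs), max_len, 5)
pad[:, :padded.size(1)] = padded                        # (3, 16, 5)

# After k reduction layers a sequence of length L occupies ceil(L / 2**k)
# reduced steps; anything beyond that is computed on padding and can be masked.
lengths = torch.tensor([s.size(0) for s in seqs])
lengths_after = (lengths + factor - 1) // factor        # ceil division
```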