I'm initializing the hidden state inside the forward function. Does this work correctly?
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MV_GRU(torch.nn.Module):
    def __init__(self, n_features, seq_length, num_hiddens, hidden_layers):
        super(MV_GRU, self).__init__()
        self.n_features = n_features
        self.seq_len = seq_length
        self.n_hidden = num_hiddens    # number of hidden units per layer
        self.n_layers = hidden_layers  # number of stacked GRU layers
        self.l_gru = torch.nn.GRU(input_size=n_features,
                                  hidden_size=self.n_hidden,
                                  num_layers=self.n_layers,
                                  batch_first=True)
        # maps the flattened per-step hidden states to a single output
        self.l_linear = torch.nn.Linear(self.n_hidden * self.seq_len, 1)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        # fresh zero-initialized hidden state each forward pass,
        # shape (num_layers, batch_size, hidden_size)
        self.hidden = torch.zeros(self.n_layers, batch_size, self.n_hidden).to(device)
        gru_out, self.hidden = self.l_gru(x, self.hidden)
        # flatten (batch, seq_len, hidden) -> (batch, seq_len * hidden)
        x = gru_out.contiguous().view(batch_size, -1)
        return self.l_linear(x)
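For reference, here is a minimal sketch of how I exercise the model; the feature count, sequence length, batch size, and layer sizes are made-up values for illustration only:

# hypothetical dimensions, chosen only for this example
n_features, seq_length = 4, 10
model = MV_GRU(n_features=n_features, seq_length=seq_length,
               num_hiddens=32, hidden_layers=2).to(device)

# dummy batch shaped (batch_size, seq_len, n_features) to match batch_first=True
x = torch.randn(8, seq_length, n_features, device=device)
out = model(x)
print(out.shape)  # torch.Size([8, 1])

As far as I understand, torch.nn.GRU already defaults the hidden state to zeros when none is passed, so the explicit initialization above should be equivalent to calling self.l_gru(x) without a second argument.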