I was trying to reset my model for every fold in a cross-validation when I realized I might have initialized my model all wrong. I am very confused right now, because it looks like every forward pass in my model resets the hidden state and cell state to zeros. I think I added the init_hidden function when I moved from training only with batch size = 1 to different batch sizes.
I'm currently rethinking what I did over the last week.
import torch
import torch.nn as nn

class Model_GRU_1(nn.Module):
    def __init__(self, n_features, n_classes, n_hidden, n_layers, dropout):
        super().__init__()
        self.gru = nn.GRU(
            input_size=n_features,
            hidden_size=n_hidden,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )
        self.dense = nn.Linear(n_hidden, n_hidden)
        self.relu = nn.ReLU()
        # this parameter is initialized here but never used in forward()
        weight = torch.zeros(n_layers, n_hidden)
        nn.init.kaiming_uniform_(weight)
        self.weight = nn.Parameter(weight)
        self.classifier = nn.Linear(n_hidden, n_classes)

    def init_hidden(self):
        # batch_size is not defined in the class, so it has to come from the
        # enclosing/global scope; the (hidden, cell) pair is LSTM-style even
        # though the recurrent layer is a GRU
        hidden_state = torch.zeros(self.gru.num_layers, batch_size, self.gru.hidden_size)
        cell_state = torch.zeros(self.gru.num_layers, batch_size, self.gru.hidden_size)
        return (hidden_state, cell_state)

    def forward(self, x):
        # fresh zero states are built on every forward pass, but they are
        # never actually passed to self.gru below
        self.hidden = self.init_hidden()
        _, hidden = self.gru(x)
        out = hidden[-1]          # last layer/direction of the final hidden state
        out2 = self.dense(out)
        out3 = self.relu(out2)
        return self.classifier(out3)
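
To make the cross-validation part concrete, here is a minimal sketch of what I mean by resetting the model for every fold: just re-instantiate it at the start of each fold. The GRUClassifier below is a simplified stand-in for my class, the data and hyperparameters are placeholders, and I'm assuming scikit-learn's KFold for the splits. As far as I can tell, nn.GRU already starts from a zero hidden state whenever forward() is called without an explicit h_0, so no manual init_hidden should be needed.

import torch
import torch.nn as nn
from sklearn.model_selection import KFold

class GRUClassifier(nn.Module):
    # simplified stand-in: no init_hidden, the GRU falls back to a zero h_0 on its own
    def __init__(self, n_features, n_classes, n_hidden):
        super().__init__()
        self.gru = nn.GRU(n_features, n_hidden, batch_first=True)
        self.classifier = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        _, h_n = self.gru(x)              # h_n: (num_layers, batch, n_hidden)
        return self.classifier(h_n[-1])

X = torch.randn(100, 20, 8)               # placeholder data: (samples, seq_len, n_features)
y = torch.randint(0, 3, (100,))           # placeholder labels for 3 classes

kfold = KFold(n_splits=5, shuffle=True, random_state=0)
for fold, (train_idx, val_idx) in enumerate(kfold.split(X)):
    train_idx = torch.as_tensor(train_idx)
    # fresh model and optimizer per fold -> freshly initialized weights, nothing carried over
    model = GRUClassifier(n_features=8, n_classes=3, n_hidden=16)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for epoch in range(3):
        optimizer.zero_grad()
        loss = criterion(model(X[train_idx]), y[train_idx])
        loss.backward()
        optimizer.step()

Re-creating the optimizer alongside the model matters here too, since optimizers like Adam keep per-parameter state between steps and that would otherwise leak across folds.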