Greetings PyTorch Community,
I am trying to implement an RNN model using GRU cell. Look for some guidance on why the init_hidden() function returning the error.
class RNN(nn.Module):
def __init__(
self,
batch_size,
e_dim = 200,
n_hidden = 6,
n_class = 6,
n_layers = 1
):
super(RNN, self).__init__()
self.n_layers = n_layers
self.e_dim = e_dim
self.n_hidden = n_hidden
self.batch_size = batch_size
self.n_class = n_class
self.gru = nn.GRU(
self.e_dim,
self.n_hidden,
num_layers = n_layers,
batch_first = True
)
self.fc = nn.Linear(self.n_hidden, self.n_class)
def init_hidden(self):
return torch.randn(self.n_layers, self.batch_size, self.hidden_size)
def forward(self, inputs):
# Avoid breaking if the last batch has a different size
batch_size = inputs.size(0)
if batch_size != self.batch_size:
self.batch_size = batch_size
output, hidden = self.gru(inputs, self.init_hidden())
output = self.fc(output[:, :, -1]).squeeze()
return output