I’m trying to stack GRUCell but got an error as above. I didn’t use GRU because the input for each sequence comes from the output (actually, modification of ) previous sequence.
class Stacked_GRU_Cells(nn.Module):
def __init__(self, input_size, hidden_size):
super(Stacked_GRU_Cells, self).__init__()
self.hidden_size = hidden_size
self.gru_0 = nn.GRUCell(input_size, hidden_size)
self.gru_1 = nn.GRUCell(hidden_size, hidden_size)
self.out = nn.Linear(hidden_size, input_size)
def forward(self, x, h_in):
if h_in is None:
h_in = torch.zeros(2, x.shape[0], self.hidden_size, device=x.device) # (2, batch_size, input_dim)
h_out = torch.zeros(2, x.shape[0], self.hidden_size, device=x.device)
h_out[0] = self.gru_0(x, h_in[0])
h_out[1] = self.gru_1(h_out[0], h_in[1])
x = self.out(h_out[1])
return x, h_out
def forward_RNN_pass(gru_rnn, input_data, hidden_size):
batch_size = input_data.size(0)
seq_len = 2
# Initialize hidden state
h = torch.zeros(2, batch_size, hidden_size, device=input_data.device)
x = input_data
# Loop over all sequences in batch
for _ in range(seq_len):
# Forward pass through GRU layer
x, h = gru_rnn(x, h)
return x, h
input_dim = 501
hidden_size = 50
gru_rnn = Stacked_GRU_Cells(input_dim, hidden_size)
# define optimizer and loss
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
for _ in range(2):
running_loss = 0.0
for i, samples in enumerate(train_dataloader):
x, y = samples
print('x', x.shape)
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
x_hat, _ = forward_RNN_pass(gru_rnn, x, hidden_size)
loss = criterion(x_hat, x)
optimizer.zero_grad()
loss.backward()
optimizer.step()