import torch
import torch.nn as nn
from torch.autograd import Variable

class rnn(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(rnn, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # batch_first defaults to False, so the LSTM expects (seq_len, batch, input_size)
        self.rnn = nn.LSTM(input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x, hidden):
        out, hidden = self.rnn(x, hidden)
        output = self.fc(out)  # applied at every time step: (seq_len, batch, num_classes)
        return output, hidden

    def init_hidden(self, batch_size):
        # h_0 and c_0 each have shape (num_layers, batch, hidden_size)
        weight = next(self.parameters()).data
        if use_gpu:  # use_gpu is a flag set elsewhere in the script
            hidden = (weight.new(self.num_layers, batch_size, self.hidden_size).zero_().cuda(),
                      weight.new(self.num_layers, batch_size, self.hidden_size).zero_().cuda())
        else:
            hidden = (weight.new(self.num_layers, batch_size, self.hidden_size).zero_(),
                      weight.new(self.num_layers, batch_size, self.hidden_size).zero_())
        return hidden
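For reference on the shapes I am assuming: with the default batch_first=False, nn.LSTM wants input of shape (seq_len, batch, input_size) and a hidden tuple (h_0, c_0), each of shape (num_layers, batch, hidden_size). A minimal standalone check with made-up toy sizes:

# toy sizes purely for illustration, not my real dimensions
lstm = nn.LSTM(input_size=10, hidden_size=32, num_layers=2)
x = torch.randn(5, 4, 10)      # (seq_len=5, batch=4, input_size=10)
h0 = torch.zeros(2, 4, 32)     # (num_layers, batch, hidden_size)
c0 = torch.zeros(2, 4, 32)
out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape)               # torch.Size([5, 4, 32])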
model = rnn(input_size, hidden_size, num_layers, num_classes)
print(model)
if use_gpu:
    model.cuda()  # move the model to the GPU once, before training
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
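One thing I noticed while reading the docs: nn.MSELoss averages the squared errors by default (reduction='mean'), so a loss around 1e7 just means predictions and targets differ by about sqrt(1e7) ≈ 3000 on average; the magnitude tracks the raw scale of the data. A toy illustration:

# the loss magnitude follows the squared scale of the targets
mse = nn.MSELoss()
pred = torch.zeros(10)
print(mse(pred, torch.full((10,), 3000.0)).item())  # 9000000.0
print(mse(pred, torch.full((10,), 0.3)).item())     # 0.09 (approximately)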
for epoch in range(num_epochs):  # num_epochs is set elsewhere
    inputs, targets = otu_handler.get_N_samples_and_targets(batch_size, seq_len)
    hidden = model.init_hidden(batch_size)
    # convert inputs and targets to tensors
    tensor1 = torch.FloatTensor(inputs)
    tensor2 = torch.FloatTensor(targets).unsqueeze(0)
    # send the tensors to the GPU
    input_tensor = Variable(tensor1.cuda(), requires_grad=True)
    # reshape from (batch, input_size, seq_len) to (seq_len, batch, input_size)
    input_tensor = input_tensor.transpose(1, 2).transpose(0, 1)
    targets = Variable(tensor2.cuda(), requires_grad=False)
    out, hidden = model(input_tensor, hidden)
    loss = criterion(out, targets)
    print("epoch: %d, loss: %0.2f" % (epoch + 1, loss.item()))
    optimizer.zero_grad()
    loss.backward()
    # nn.utils.clip_grad_norm_(model.parameters(), 2)
    optimizer.step()
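In case it matters for spotting the logical error: out comes back as (seq_len, batch, num_classes), while targets is whatever unsqueeze(0) makes of it, and as far as I can tell nn.MSELoss broadcasts mismatched shapes instead of raising an error (recent PyTorch versions only print a warning). A standalone illustration with toy sizes:

# toy shapes for illustration only
mse = nn.MSELoss()
out = torch.randn(5, 4, 3)   # (seq_len, batch, num_classes)
tgt = torch.randn(1, 4, 3)   # a (batch, num_classes) target after unsqueeze(0)
print(mse(out, tgt).item())  # broadcasts tgt across all 5 time steps (with a UserWarning)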
The output is something like this:
epoch: 1, loss: 17322482.00
epoch: 2, loss: 13263158.00
epoch: 3, loss: 17829800.00
epoch: 4, loss: 2325517312.00
epoch: 5, loss: 25319476.00
epoch: 6, loss: 18578794.00
epoch: 7, loss: 12342423.00
epoch: 8, loss: 168415920.00
epoch: 9, loss: 17259250.00
epoch: 10, loss: 175974816.00
epoch: 11, loss: 31784454.00
epoch: 12, loss: 17423422.00
epoch: 13, loss: 20637374.00
epoch: 14, loss: 36012860.00
epoch: 15, loss: 27672722.00
epoch: 16, loss: 19312454.00
epoch: 17, loss: 91513688.00
epoch: 18, loss: 1654813952.00
epoch: 19, loss: 19895126.00
epoch: 20, loss: 60809964.00
epoch: 21, loss: 20559496.00
epoch: 22, loss: 18604082.00
epoch: 23, loss: 18246324.00
epoch: 24, loss: 36698088.00
epoch: 25, loss: 21916944.00
epoch: 26, loss: 26092220.00
epoch: 27, loss: 17202180.00
epoch: 28, loss: 20631326.00
epoch: 29, loss: 22352708.00
epoch: 30, loss: 17544972.00
epoch: 31, loss: 19844386.00
epoch: 32, loss: 3089386496.00
epoch: 33, loss: 21927742.00
epoch: 34, loss: 19233062.00
epoch: 35, loss: 24233808.00
epoch: 36, loss: 14247420.00
epoch: 37, loss: 19866096.00
epoch: 38, loss: 19247676.00
epoch: 39, loss: 40788848.00
epoch: 40, loss: 178087904.00
epoch: 41, loss: 32632774.00
epoch: 42, loss: 49278888.00
epoch: 43, loss: 13424591.00
epoch: 44, loss: 13337734.00
epoch: 45, loss: 17201510.00
epoch: 46, loss: 44591204.00
epoch: 47, loss: 25328970.00
epoch: 48, loss: 14413733.00
epoch: 49, loss: 22293836.00
epoch: 50, loss: 23427574.00
epoch: 51, loss: 189332624.00
epoch: 52, loss: 26622992.00
epoch: 53, loss: 47797516.00
epoch: 54, loss: 45296728.00
epoch: 55, loss: 41071708.00
epoch: 56, loss: 25053186.00
epoch: 57, loss: 27240572.00
epoch: 58, loss: 33122594.00
epoch: 59, loss: 14874048.00
epoch: 60, loss: 20430304.00
epoch: 61, loss: 21469500.00
epoch: 62, loss: 15457670.00
epoch: 63, loss: 17139502.00
epoch: 64, loss: 17082172.00
epoch: 65, loss: 26391324.00
epoch: 66, loss: 40719556.00
epoch: 67, loss: 18023896.00
epoch: 68, loss: 16934692.00
epoch: 69, loss: 26133756.00
epoch: 70, loss: 14400602.00
epoch: 71, loss: 17984878.00
epoch: 72, loss: 926914624.00
epoch: 73, loss: 21649504.00
epoch: 74, loss: 16226421.00
epoch: 75, loss: 15451624.00
epoch: 76, loss: 22588744.00
epoch: 77, loss: 42169820.00
I know there is some logical error in the code that I am not able to pinpoint.
Any help is appreciated.
Thank you.