I am using the forward method below. The code compiles successfully, but at runtime I get an error stating that lstm expects two hidden states.
Any guidance would be appreciated.
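For reference, before the full listing: my understanding of the hx argument (which I assume is what the "two hidden states" check is about) is that it must hold exactly two tensors, h0 and c0. The sketch below is just that assumption spelled out; the function name and all sizes are placeholders, not my real configuration.

#include <torch/torch.h>
#include <vector>

// Sketch only: hx is a list of exactly two tensors, {h0, c0}, each of shape
// (num_layers * num_directions, batch, hidden_size). All sizes are placeholders.
void hx_sketch() {
  const int64_t num_layers = 1, num_directions = 1, batch = 4, hidden = 8;
  std::vector<torch::Tensor> hx = {
      torch::zeros({num_layers * num_directions, batch, hidden}),  // h0
      torch::zeros({num_layers * num_directions, batch, hidden})   // c0
  };
  // at::lstm takes hx as a TensorList; a std::vector<torch::Tensor> converts
  // implicitly and keeps both tensors alive for the duration of the call.
}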
#include <ATen/ATen.h>
#include <torch/torch.h>
#include <iostream>
#include <string>
#include <vector>
#include "rnn.h"
std::vector<torch::Tensor> PackedLSTMImpl::flat_weights() const {
  // Organize all weights in a flat vector in the order
  // (weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0),
  // repeated for each layer (next to each other).
  std::vector<torch::Tensor> flat;
  std::cout << "Flattening weights..." << std::endl;
  const auto num_directions = rnn_->options.bidirectional() ? 2 : 1;
  for (int64_t layer = 0; layer < rnn_->options.num_layers(); layer++) {
    for (auto direction = 0; direction < num_directions; direction++) {
      const auto layer_idx =
          static_cast<size_t>((layer * num_directions) + direction);
      const std::string layer_num = std::to_string(layer);
      flat.push_back(rnn_->named_parameters()["weight_ih_l" + layer_num][layer_idx]);
      flat.push_back(rnn_->named_parameters()["weight_hh_l" + layer_num][layer_idx]);
      if (rnn_->options.bias()) {
        flat.push_back(rnn_->named_parameters()["bias_ih_l" + layer_num][layer_idx]);
        flat.push_back(rnn_->named_parameters()["bias_hh_l" + layer_num][layer_idx]);
      }
      // Earlier attempt with the layer index hard-coded to l0:
      // flat.push_back(rnn_->named_parameters()["weight_ih_l0"][layer_idx]);
      // flat.push_back(rnn_->named_parameters()["weight_hh_l0"][layer_idx]);
      // if (rnn_->options.bias()) {
      //   flat.push_back(rnn_->named_parameters()["bias_ih_l0"][layer_idx]);
      //   flat.push_back(rnn_->named_parameters()["bias_hh_l0"][layer_idx]);
      // }
    }
  }
  return flat;
}
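// Debug-only sketch (not part of the module): I dump the collected shapes to
// sanity-check the per-layer (w_ih, w_hh, b_ih, b_hh) order before handing the
// list to at::lstm. The helper name is just mine for illustration.
void dump_flat_weights(const std::vector<torch::Tensor>& flat) {
  for (size_t i = 0; i < flat.size(); ++i) {
    std::cout << "param " << i << ": " << flat[i].sizes() << std::endl;
  }
}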
std::tuple<torch::Tensor, torch::Tensor> PackedLSTMImpl::forward(const torch::Tensor& input,
                                                                 const at::Tensor& lengths,
                                                                 torch::Tensor state) {
  std::cout << "in PackedLSTMImpl::forward()..." << std::endl;
  if (!state.defined()) {
    // 2 for hidden state and cell state, then #layers * #directions, batch size, state size
    // const auto batch_size = input.size(rnn_->options.batch_first_ ? 0 : 1);
    const auto max_batch_size = lengths[0].item().toLong();
    const auto num_directions = rnn_->options.bidirectional() ? 2 : 1;
    state = torch::zeros({2, rnn_->options.num_layers() * num_directions,
                          max_batch_size, rnn_->options.hidden_size()},
                         input.options());
  }
  torch::Tensor output, hidden_state, cell_state;
  // at::lstm's params argument is a TensorList (an ArrayRef), so the vector of
  // flattened weights can be passed directly without copying it element-wise.
  std::vector<torch::Tensor> flattened_weights = flat_weights();
  // Keep the two state slices alive in a vector that owns them; hx must contain
  // exactly two tensors (hidden state and cell state).
  std::vector<torch::Tensor> hidden_states = {state[0], state[1]};
  // aten::lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor)
  // at::lstm(at::Tensor const&, c10::ArrayRef<at::Tensor>, c10::ArrayRef<at::Tensor>, bool, long, double, bool, bool, bool)
  std::tie(output, hidden_state, cell_state) = at::lstm(
      input, lengths, hidden_states, flattened_weights,
      rnn_->options.bias(), rnn_->options.num_layers(), rnn_->options.dropout(),
      rnn_->is_training(), rnn_->options.bidirectional());
  return {output, torch::stack({hidden_state, cell_state})};
}
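And to confirm the state layout I am assuming: (h, c) are packed into one tensor of shape (2, num_layers * num_directions, batch, hidden_size) and split back out by indexing, roughly as in this standalone sketch (all sizes are placeholders):

#include <torch/torch.h>

// Standalone sketch of the state round trip assumed above; sizes are placeholders.
void state_roundtrip_sketch() {
  torch::Tensor h0 = torch::zeros({1, 4, 8});     // (layers * dirs, batch, hidden)
  torch::Tensor c0 = torch::zeros({1, 4, 8});
  torch::Tensor packed = torch::stack({h0, c0});  // shape (2, 1, 4, 8)
  torch::Tensor h_back = packed[0];               // recovers (1, 4, 8)
  torch::Tensor c_back = packed[1];
}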