Hello, how do I use `torch::narrow` in a `torch::nn::SequentialImpl`? I am new to the C++ API. My first attempt was to create a `Functional` module, which I understand wraps a free function, but this doesn't compile:
```cpp
#include <torch/torch.h>
#include <iostream>

class NarrowingImpl : public torch::nn::SequentialImpl {
 public:
  NarrowingImpl() {
    using namespace torch::nn;
    // My first attempt, which does not compile:
    push_back(Functional(torch::narrow, 0, 1, 1));
    // This compiles, but _narrow_with_range takes an end index instead of a length:
    // push_back(Functional(functional::_narrow_with_range, 0, 1, 2));
  }
};
TORCH_MODULE(Narrowing);

int main() {
  Narrowing nrw;
  torch::Tensor x = torch::rand({2, 3, 4});
  std::cout << x << std::endl << std::endl;
  std::cout << nrw->forward(x) << std::endl;
}
```
As a side note, I was able to get it to work with `functional::_narrow_with_range`, but I'm not sure whether that is the correct way to go about this.
Any help would be much appreciated. Thank you!