Hey,
I’m trying to use libtorch to create and save models. Following the DCGAN tutorial and this I create a basic script as given here
class TestImpl: torch::nn::Module
{
public:
TestImpl(int inPlanes, int outClasses)
: m_score(torch::nn::Conv2dOptions(inPlanes, outClasses, 1))
{
m_score->to(at::kCUDA);
register_module("conv",m_score);
}
double forward(at::Tensor& input, at::Tensor& output)
{
cudaDeviceSynchronize();
auto start = std::chrono::system_clock::now();
output = m_score->forward(input);
cudaDeviceSynchronize();
auto end = std::chrono::system_clock::now();
std::chrono::duration<double> elapsed_seconds = end-start;
return double(elapsed_seconds.count());
}
private:
torch::nn::Conv2d m_score{nullptr};
};
TORCH_MODULE(Test);
int main(int argc, char** argv )
{
Test net(3,10);
int loop = 100;
int skip = 20;
float timing;
at::Tensor outputs;
at::Tensor inputs = torch::ones({1, 3, 1024, 1024},torch::TensorOptions().device(torch::kCUDA));
for (auto i = 0; i < loop+skip; i++)
{
auto tmp = net->forward(inputs, outputs);//_pass(inputs, outputs);
if(i >= skip)
timing += tmp;
}
std::cout<< "Average forward pass in: " << timing/loop <<std::endl;
torch::save(net, "chkValues.pt");
}
When I comment out the last line (torch::save
) it compiles fine but un-commenting this throws out a lot of errors starting with:
In file included from libtorch_stable/include/torch/csrc/api/include/torch/nn/module.h:3:0,
from libtorch_stable/include/torch/csrc/api/include/torch/nn/cloneable.h:3,
from libtorch_stable/include/torch/csrc/api/include/torch/nn.h:3,
from libtorch_stable/include/torch/csrc/api/include/torch/all.h:7,
fromlibtorch_stable/include/torch/csrc/api/include/torch/torch.h:3,
from timing.cpp:36:
libtorch_stable/include/torch/csrc/api/include/torch/nn/pimpl.h: In instantiation of ‘torch::serialize::OutputArchive& torch::nn::operator<<(torch::serialize::OutputArchive&, const torch::nn::ModuleHolder<ModuleType>&) [with ModuleType = TestImpl]’:
libtorch_stable/include/torch/csrc/api/include/torch/serialize.h:43:11: required from ‘void torch::save(const Value&, SaveToArgs&& ...) [with Value = Test; SaveToArgs = {const char (&)[13]}]’
timing.cpp:127:34: required from here
libtorch_stable/include/torch/csrc/api/include/torch/nn/pimpl.h:179:18: error: no match for ‘operator<<’ (operand types are ‘torch::serialize::OutputArchive’ and ‘const std::shared_ptr<TestImpl>’)
return archive << module.ptr();
Could someone please help me understand why this is happening?
I’m using libtorch-stable 1.4 version with cuda 9.2.