Hi!
I’m using a TorchScript model (exported from Python) to run a forward pass on a relatively large audio tensor in a C++ program. Because the forward pass can take a long time, I’d like the user to be able to cancel it halfway through.
A (rather hacky) workaround is to run the forward pass in a separate thread and detach that thread if the user wants to cancel:
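```cpp
#include <torch/script.h>
#include <wx/utils.h> // wxMilliSleep

#include <atomic>
#include <memory>
#include <thread>
#include <vector>

// Everything the worker thread touches lives in a heap-allocated struct owned
// by a shared_ptr. If the thread is detached and this function returns, the
// thread keeps the state alive instead of reading destroyed locals.
struct ForwardState
{
    torch::jit::Module model; // loaded elsewhere, e.g. with torch::jit::load()
    torch::Tensor input;
    torch::Tensor output;
    std::atomic<bool> done{false};
};

auto state = std::make_shared<ForwardState>();
// ... assign state->model and state->input here ...

std::thread worker([state]()
{
    // forward pass (torch::jit::Module is a value type, so '.' rather than '->')
    std::vector<torch::jit::IValue> inputs = {state->input};
    state->output = state->model.forward(inputs).toTensor();
    state->done = true;
});

// wait for the thread to finish
while (!state->done)
{
    if (UserWantsToCancel())
    {
        // abort if requested; the detached thread still runs the forward
        // pass to completion in the background, it just no longer blocks us
        worker.detach();
        return;
    }
    wxMilliSleep(50);
}
worker.join();
```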
I was wondering whether the C++ torch API (libtorch) provides a way to interrupt a forward pass while it is running, for example through a callback or hook?
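If no such hook exists, one cooperative alternative is to split the audio into chunks, run the model on each chunk, and check a cancellation flag between chunks, so the worker thread can be joined instead of detached. The sketch below is only an illustration under assumptions: it assumes the model can be applied to independent chunks whose outputs are simply concatenated (not true for every audio model, e.g. one whose receptive field spans chunk boundaries). `ProcessChunk` and `ForwardCancellable` are hypothetical names; `ProcessChunk` is a placeholder for the real call, which with libtorch would be something along the lines of `model.forward({chunk}).toTensor()`, and `std::vector<float>` stands in for `torch::Tensor` to keep the example self-contained (C++17, for `std::optional`).

```cpp
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical placeholder for one model call on a chunk of samples.
// With libtorch this would wrap something like model.forward({chunk}).toTensor().
std::vector<float> ProcessChunk(const std::vector<float>& chunk)
{
    std::vector<float> out;
    out.reserve(chunk.size());
    for (float sample : chunk)
        out.push_back(sample * 0.5f); // dummy computation standing in for the model
    return out;
}

// Runs ProcessChunk over `audio` in pieces of `chunkSize` samples and checks
// `cancel` before each piece. Returns std::nullopt if cancellation was requested.
std::optional<std::vector<float>> ForwardCancellable(
    const std::vector<float>& audio,
    std::size_t chunkSize,
    const std::atomic<bool>& cancel)
{
    assert(chunkSize > 0 && "chunkSize must be positive");
    std::vector<float> result;
    result.reserve(audio.size());
    for (std::size_t start = 0; start < audio.size(); start += chunkSize)
    {
        if (cancel.load())
            return std::nullopt; // give up before starting the next chunk
        std::size_t end = std::min(audio.size(), start + chunkSize);
        std::vector<float> chunk(audio.begin() + start, audio.begin() + end);
        std::vector<float> out = ProcessChunk(chunk);
        result.insert(result.end(), out.begin(), out.end());
    }
    return result;
}
```

In use, the worker thread would run `output = ForwardCancellable(audio, 4096, cancel);` while the UI loop sets `cancel = true` when `UserWantsToCancel()` returns true and then calls `join()`, which returns after at most one chunk's worth of work. The chunk size trades cancellation latency against per-call overhead.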
Any help would be appreciated! Thanks so much!