Progress callback // Interrupting execution of `torch::jit::Module::forward()`

Hi!

I’m using a TorchScript model (exported from Python) to run a forward pass on a relatively large audio tensor in a C++ program. Because the forward pass can take a long time, I’d like the user to be able to cancel it partway through.

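A (rather hacky) workaround is to run the forward pass in a separate thread and detach that thread if the user cancels. Detaching does not stop the computation: the forward pass keeps running in the background until it completes, so everything the thread touches must outlive the calling function: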

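#include <atomic>
#include <memory>
#include <thread>
#include <vector>

#include <torch/script.h>
#include <wx/utils.h> // wxMilliSleep

// State shared with the worker lives on the heap and is reference-counted,
// so a detached worker never writes to variables that have gone out of scope.
struct ForwardJob
{
   torch::jit::Module model;
   torch::Tensor input;
   torch::Tensor output;
   std::atomic<bool> done{false};
};

auto job = std::make_shared<ForwardJob>();
// ... load job->model (e.g. with torch::jit::load) and fill job->input ...

auto thread = std::thread(
   [job]()
   {
      // forward pass
      std::vector<torch::jit::IValue> inputs = {job->input};
      job->output = job->model.forward(inputs).toTensor();

      job->done = true;
   }
);

// wait for the thread to finish
while (!job->done)
{
   if (UserWantsToCancel())
   {
      // abort if requested; the worker keeps its own reference to `job`
      thread.detach();
      return;
   }

   wxMilliSleep(50);
}

thread.join();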
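If the model can be run in pieces (for example one slice of the audio per call, which assumes the model needs no context across slice boundaries), cancellation can be made cooperative instead: the worker checks a flag between pieces, and the caller sets the flag and joins rather than detaching. This is a minimal sketch using only the standard library; `CancellableJob`, `wait_or_cancel`, and `run_chunk` are hypothetical names, and in real code `run_chunk(i)` would call `model.forward()` on the i-th slice of the input (for instance one element of `input.split(chunk_size, /*dim=*/-1)`).

```cpp
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <functional>
#include <thread>
#include <utility>

// Hypothetical wrapper (not a LibTorch API). Runs run_chunk(0) .. run_chunk(n-1)
// on a background thread and checks a cancellation flag before each chunk.
// The destructor requests cancellation and joins, so the worker never
// outlives this object and never has to be detached.
class CancellableJob
{
public:
   CancellableJob(std::size_t num_chunks, std::function<void(std::size_t)> run_chunk)
      : worker_([this, num_chunks, run_chunk = std::move(run_chunk)]()
        {
           for (std::size_t i = 0; i < num_chunks; ++i)
           {
              if (cancelled_.load())
                 return;    // stop before starting the next chunk
              run_chunk(i); // e.g. model.forward() on one slice of the audio
              completed_.fetch_add(1);
           }
           done_.store(true);
        })
   {
   }

   ~CancellableJob()
   {
      cancel();
      worker_.join(); // waits at most for the chunk currently running
   }

   CancellableJob(const CancellableJob&) = delete;
   CancellableJob& operator=(const CancellableJob&) = delete;

   void cancel() { cancelled_.store(true); }
   bool done() const { return done_.load(); }
   std::size_t completed() const { return completed_.load(); }

private:
   std::atomic<bool> cancelled_{false};
   std::atomic<bool> done_{false};
   std::atomic<std::size_t> completed_{0};
   std::thread worker_; // declared last: the flags above exist before it starts
};

// Polling loop in the style of the question. Returns true if the job
// finished, false if wants_cancel() fired first (the job is then cancelled).
bool wait_or_cancel(CancellableJob& job, const std::function<bool()>& wants_cancel,
                    std::chrono::milliseconds poll = std::chrono::milliseconds(50))
{
   assert(wants_cancel);

   while (!job.done())
   {
      if (wants_cancel())
      {
         job.cancel();
         return false;
      }
      std::this_thread::sleep_for(poll);
   }
   return true;
}
```

In the question's code, `wait_or_cancel(job, UserWantsToCancel)` (assuming `UserWantsToCancel` returns `bool`) would take the place of the `while` loop. A cancel then waits at most for the slice currently being processed, since a single `forward()` call still cannot be interrupted.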

I was wondering whether the C++ torch API (LibTorch) provides an interface for interrupting a forward pass, perhaps through some sort of callback or hook?

Any help would be appreciated! Thanks so much!

I'm very interested in this too. It seems there is simply no option to terminate a `forward()` call midway.
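If the audio can be processed in independent chunks (an assumption that depends on the model), a progress callback like the one in the thread title can run between chunks: it receives the fraction completed and returns `false` to stop. This is a minimal sketch, not a LibTorch feature; `run_chunked` is a hypothetical helper, and it can only stop between `forward()` calls, never inside one.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>

// Hypothetical helper (not a LibTorch API). Runs run_chunk(0) .. run_chunk(n-1)
// in order and calls on_progress after each chunk with the fraction completed.
// on_progress returns false to cancel; cancellation takes effect between
// chunks, never in the middle of one.
// Returns true if every chunk ran and was acknowledged, false if cancelled.
bool run_chunked(std::size_t num_chunks,
                 const std::function<void(std::size_t)>& run_chunk,
                 const std::function<bool(double)>& on_progress)
{
   assert(run_chunk && on_progress);

   for (std::size_t i = 0; i < num_chunks; ++i)
   {
      run_chunk(i); // e.g. model.forward() on the i-th slice of the audio

      const double fraction = static_cast<double>(i + 1) / static_cast<double>(num_chunks);
      if (!on_progress(fraction))
         return false; // user cancelled; remaining chunks are skipped
   }
   return true;
}
```

The per-chunk outputs could then be concatenated with `torch::cat(outputs, /*dim=*/-1)`, and the callback is a natural place to update a progress bar and check `UserWantsToCancel()`.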