Edge case with register_hook

This may be a fairly rare edge case, but I wanted to confirm whether this should be possible with the C++ API.
Suppose I have the following program, which calls backward on one graph inside the backward hook of another graph, from a different thread.

I always need to call backward in a different thread because the main thread must stay available to execute the hooks, which in my case call back into an interface that only supports calls from the main thread.

#include <torch/torch.h>
#include <iostream> // needed for std::cout
#include <thread>

int main()
{

    auto x = torch::randn({1}, torch::requires_grad());
    auto y = 2 * x;

    std::function<void(torch::Tensor)> hook1 = [](torch::Tensor grad) {
        std::cout << "hey" << std::endl;
    };
    x.register_hook(hook1);

    auto a = torch::randn({1}, torch::requires_grad());
    auto b = 2 * a;

    std::function<void(torch::Tensor)> hook2 = [&](torch::Tensor grad) {
        std::cout << "hello" << std::endl;
        std::thread t([&]() {
            y.backward();
        });
        t.join();
    };

    a.register_hook(hook2);

    std::thread t2([&]() {
        b.backward();
    });
    // wait for hooks

    t2.join();

    return 0;
}

Hi,

Yes this will work. Do you see any issue with it?

Hi @albanD

Yes, sorry! I forgot to mention that this deadlocks: torch seems to be waiting for a mutex to be unlocked before it can run the backward operation inside the hook.

Ah, right.
The thing is that your hook actually waits on the other backward to finish, because it joins the other thread.
And because the hook is blocked waiting on that, another thread cannot run backward (the current thread still can).
So you either want to run this other backward in the same thread as the hook, or not block the hook waiting on that backward.

OK! Thanks for your reply, I’ll figure out the best way to avoid spinning up a new thread for each backward call.