THC threading support

Hi everyone,

I wasn’t quite sure where to bring this question, so I thought I’d try my luck here. If there is a more appropriate place for me to go, please let me know and I’ll quit pestering this board!

Basically, I’m interested in using the THC library outside of PyTorch, as part of a Rust package rather than a C++ one. For those of you who don’t know Rust, it is similar in some ways to C++, but the key point is that it compiles to native machine code (no interpreter or VM involved), so it is easy to declare the THC C API in Rust and link against it at build time. In fact, most of this already works in the project I am working on: I can run deep learning models (inference only for now) from my Rust code using THC, through a thin Rust wrapper. Everything is gravy except one little thing.

The project has a suite of unit tests covering basic functionality (e.g. matrix multiplication, convolutions, LSTM forward passes). When I run the tests on a single thread (i.e. serially), they pass every time. When I run them in parallel on two or more threads, however, some of them intermittently produce incorrect results. This really feels like a concurrency issue: right after a failed assertion I print the left- and right-hand sides of the equality being tested, and the two tensors are equal. That suggests that at some point between the failed assertion and the print, the GPU finished its work and the output tensor became correct.
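If the failures really come from concurrent GPU work, serializing only the GPU-touching part of each test should make them go away while still letting `cargo test` run everything else in parallel. A minimal Rust sketch, assuming a process-wide lock; `GPU_TEST_LOCK`, `gpu_test_guard` and `max_concurrent_gpu_sections` are hypothetical names, and the actual THC calls are left as a placeholder comment:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Mutex, MutexGuard};
use std::thread;

// Hypothetical process-wide lock: every test that touches the GPU takes it first,
// so GPU work from different test threads can never interleave.
static GPU_TEST_LOCK: Mutex<()> = Mutex::new(());

/// Returns a guard that holds the lock until it is dropped. If an earlier test
/// panicked while holding the lock, recover the guard instead of failing every
/// later test with a PoisonError.
fn gpu_test_guard() -> MutexGuard<'static, ()> {
    GPU_TEST_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}

/// Spawns `n` threads that each run a placeholder "GPU section" under the lock
/// and returns the largest number of threads ever observed inside it at once.
fn max_concurrent_gpu_sections(n: usize) -> usize {
    let in_section = AtomicUsize::new(0);
    let max_seen = AtomicUsize::new(0);
    thread::scope(|s| {
        for _ in 0..n {
            s.spawn(|| {
                let _guard = gpu_test_guard();
                let now = in_section.fetch_add(1, Ordering::SeqCst) + 1;
                max_seen.fetch_max(now, Ordering::SeqCst);
                // ... launch kernels, synchronize, and copy results back here ...
                thread::yield_now();
                in_section.fetch_sub(1, Ordering::SeqCst);
                // `_guard` is dropped here, releasing the lock for the next thread.
            });
        }
    });
    max_seen.load(Ordering::SeqCst)
}

fn main() {
    println!("max concurrent GPU sections: {}", max_concurrent_gpu_sections(8));
}
```

In a real test the first line would be `let _guard = gpu_test_guard();`, held until the results have been copied back to the host. This is a diagnostic rather than a fix: it shows whether overlap between tests matters, not why.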

I have gone through the THC and PyTorch code in great detail and I can’t find anything I am missing about the proper use of the THC API. So I came here to ask an admittedly broad question: what might I be doing wrong?

  • Is there an obvious gotcha when using THC in a multi-threaded environment that I might be missing (e.g. does initialization have to happen at a specific point relative to the creation of threads)?

  • Are there any good examples of THC being used in a multi-threaded environment? (Python examples tend to use multiprocessing rather than multithreading, presumably because of the well-known GIL issues.)
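On the initialization question, one common Rust pattern is to create shared global state exactly once, from whichever thread gets there first, and hand every thread a reference to it. A sketch assuming `std::sync::OnceLock` (Rust 1.70+); `GpuContext`, `context` and `init_from_threads` are hypothetical, and the real wrapper initialization is only a placeholder comment. One-time global init does not cover per-thread state: the CUDA runtime, for example, tracks the current device separately for each host thread.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;
use std::thread;

/// Hypothetical stand-in for the wrapper's global GPU state.
struct GpuContext {
    device: i32,
}

// Counts how many times initialization actually ran, to show it happens once.
static INIT_CALLS: AtomicUsize = AtomicUsize::new(0);
static CONTEXT: OnceLock<GpuContext> = OnceLock::new();

/// Returns the shared context, initializing it on first use. OnceLock guarantees
/// the closure runs at most once, even if many threads race to call this.
fn context() -> &'static GpuContext {
    CONTEXT.get_or_init(|| {
        INIT_CALLS.fetch_add(1, Ordering::SeqCst);
        // ... call the wrapper's one-time GPU initialization here (placeholder) ...
        GpuContext { device: 0 }
    })
}

/// Calls `context()` from `n` threads at once and returns
/// (number of distinct context addresses seen, number of init calls).
fn init_from_threads(n: usize) -> (usize, usize) {
    let mut addrs: Vec<usize> = thread::scope(|s| {
        let handles: Vec<_> = (0..n)
            .map(|_| s.spawn(|| context() as *const GpuContext as usize))
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).collect()
    });
    addrs.sort_unstable();
    addrs.dedup();
    (addrs.len(), INIT_CALLS.load(Ordering::SeqCst))
}

fn main() {
    let (distinct, inits) = init_from_threads(8);
    println!(
        "distinct contexts: {distinct}, init calls: {inits}, device: {}",
        context().device
    );
}
```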

I am happy to go into some detail about how I am using the API in my code.

Any tips at all will be very helpful.

Thanks!

Hi Joshua,

how are you comparing the two results, and how can you tell that you “see equal tensors”? The two tensors you see may differ very slightly. So instead of testing A == B, maybe try something like norm(A - B) < thresh, where thresh is very small, e.g. 1e-9.
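The suggested check, norm(A - B) < thresh, might look like this in Rust over flat element slices. A minimal sketch; `close_enough` is a hypothetical helper, and it accumulates in f64 to limit rounding in the sum:

```rust
/// Returns true if the Euclidean norm of `a - b` is strictly below `thresh`.
/// Any NaN makes the sum NaN, and `NaN < thresh` is false, so NaNs fail the check.
fn close_enough(a: &[f32], b: &[f32], thresh: f64) -> bool {
    // Tensors with different element counts can never match.
    if a.len() != b.len() {
        return false;
    }
    let sq_sum: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let d = f64::from(x) - f64::from(y);
            d * d
        })
        .sum();
    sq_sum.sqrt() < thresh
}

fn main() {
    let a = [1.0f32, 2.0, 3.0];
    let b = [1.0f32, 2.0, 3.001];
    println!("{}", close_enough(&a, &a, 1e-9)); // true: identical tensors
    println!("{}", close_enough(&a, &b, 1e-9)); // false: they differ by about 1e-3
    println!("{}", close_enough(&a, &b, 1e-2)); // true
}
```

For f32 values of order 1, any nonzero difference is already around 1e-7 (see `f32::EPSILON`), so a 1e-9 threshold effectively requires bit-identical results.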

Great question. For floating-point types I use a threshold approach like the one you mentioned; for integer types I use exact equality. And remember, the tests pass when run serially on one thread, so my gut tells me the tests themselves are sound and something else is going on.
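The per-type policy described here (a tolerance for floats, exact equality for integers) can be expressed as a small trait. A sketch; `TestEq`, `exact_test_eq` and `tensors_match` are hypothetical names, and unlike the norm check it uses a per-element absolute tolerance:

```rust
/// How one element type is compared in tests.
trait TestEq: Copy {
    fn test_eq(self, other: Self, tol: f64) -> bool;
}

impl TestEq for f32 {
    fn test_eq(self, other: Self, tol: f64) -> bool {
        // Absolute per-element tolerance; NaN never compares equal.
        (f64::from(self) - f64::from(other)).abs() <= tol
    }
}

// Integer types ignore the tolerance and require exact equality.
macro_rules! exact_test_eq {
    ($($t:ty),*) => {
        $(impl TestEq for $t {
            fn test_eq(self, other: Self, _tol: f64) -> bool {
                self == other
            }
        })*
    };
}
exact_test_eq!(u8, i32, i64);

/// True if both tensors have the same length and every pair of elements matches.
fn tensors_match<T: TestEq>(a: &[T], b: &[T], tol: f64) -> bool {
    a.len() == b.len() && a.iter().zip(b).all(|(&x, &y)| x.test_eq(y, tol))
}

fn main() {
    println!("{}", tensors_match(&[1.0f32, 2.5], &[1.0, 2.75], 1e-3)); // false
    println!("{}", tensors_match(&[1.0f32, 2.5], &[1.0, 2.75], 0.5)); // true
    println!("{}", tensors_match(&[1i64, 2, 3], &[1, 2, 4], 10.0)); // false
}
```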

Any further thoughts on this question?