Looking at this first line of the profiler output:
Name                                CPU time        CUDA time       Calls           CPU total       CUDA total
----------------------------------- --------------- --------------- --------------- --------------- ---------------
_th_get_device                      79.467us        77.472us        1               79.467us        77.472us
I’m seeing lots of calls to this _th_get_device function. In fact, when I post-processed the output with a Python script, I found that among the top 5 items by GPU time, it was the second highest, right after mm (which I assume is matrix multiplication).
Is there any way to reduce this, or is it just the nature of the beast? I can’t find much information about what _th_get_device actually does. From the name, I assume it has something to do with the CPU making a request to the GPU, in which case it probably can’t be optimized much. Is there some way to reduce the number of these calls, though?
Also, is there somewhere I can look up what the various function calls are doing?
We had a bug where _th_get_device was taking much more time than it should have, as you’ve encountered here. I’ve fixed this on master (see https://github.com/pytorch/pytorch/issues/13049); try a nightly build and see whether performance improves for you.
The best way to find out what the various function calls are doing is to grep the codebase. If you have questions about where to find the code for a specific function, feel free to ask.
Richard, I can’t seem to find many of the function names from the profiler in the PyTorch code. For example, consider the five functions that take the most time.
I find the string _th_get_device in aten/src/ATen/core/aten_interned_strings.h, but there isn’t really a function definition there.
For LSTMFusedBackward, I don’t find that exact string by grepping, but I do find LSTMFused in torch/nn/_functions/thnn/auto.py, so apparently LSTMFusedBackward is a function that’s generated there? I’m guessing that many of these functions are actually generated in auto.py. MmBackward, for example, seems to be the backward pass of mm.
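The search strategy described here (look for the exact profiler name, and if that fails, drop the `Backward` suffix and look for the forward name) can be sketched in Python as follows. `find_in_tree` and `find_op` are hypothetical helpers, not a PyTorch tool, and the file extensions they scan are an assumption. They only find names that appear literally in the checked-out source, so code that is generated at build time may still not show up.

```python
import os


def find_in_tree(root, needle, exts=(".py", ".h", ".cpp", ".yaml")):
    """Return (path, line_number) pairs for lines under root containing needle."""
    hits = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for fname in filenames:
            # Only scan file types assumed to hold relevant source.
            if not fname.endswith(exts):
                continue
            path = os.path.join(dirpath, fname)
            try:
                with open(path, encoding="utf-8", errors="ignore") as f:
                    for lineno, line in enumerate(f, 1):
                        if needle in line:
                            hits.append((path, lineno))
            except OSError:
                continue
    return hits


def find_op(root, op_name):
    """Search for op_name; if absent and it ends in 'Backward', retry without the suffix.

    Returns (name_actually_searched, hits).
    """
    hits = find_in_tree(root, op_name)
    if hits or not op_name.endswith("Backward"):
        return op_name, hits
    base = op_name[: -len("Backward")]
    return base, find_in_tree(root, base)
```

For example, `find_op("path/to/pytorch", "LSTMFusedBackward")` would fall back to searching for `LSTMFused` when the full name is not found anywhere in the tree.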