According to the PyTorch profiler, calls to _th_get_device are some of the most expensive ops in my model. The op is called a couple thousand times, and I don’t see anything in my code that would obviously trigger it. I thought checking the device type of a tensor might be the culprit, but I’m not doing that anywhere. I only cast to a CudaTensor a few times, not thousands.
Which methods call this op (_th_get_device)?
UPDATE: It seems like slice might be checking the device since it needs to allocate a new tensor. Is that unavoidable?
Here’s an idea of what is being shown by the profiler: https://gist.github.com/mrdrozdov/2fc8fbf0ed2a77ea2f568bcea441bb97
The ops slice, cat, and _th_get_device are the most costly.
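To narrow down which calls trigger which ops, one option is to profile each suspect operation in isolation and compare the per-op call counts. Below is a minimal sketch of that idea, not my actual model code: `op_counts`, `do_slices`, and `do_cat` are hypothetical helper names, the tensor shape is made up, and it runs on a CPU tensor so it works without a GPU. The op names reported depend on the PyTorch version, so `_th_get_device` may or may not appear in the output.

```python
import torch
from torch.autograd import profiler


def op_counts(fn):
    """Run fn under the autograd profiler and return {op name: call count}."""
    with profiler.profile() as prof:
        fn()
    # key_averages() groups recorded events by op name.
    return {evt.key: evt.count for evt in prof.key_averages()}


# Made-up CPU tensor; move it to CUDA to mirror a GPU setup.
x = torch.randn(64, 128)


# Hypothetical stand-ins for the slicing and concatenation in the real model.
def do_slices():
    return [x[:, i:i + 8] for i in range(0, 128, 8)]


def do_cat():
    return torch.cat(do_slices(), dim=1)


slice_counts = op_counts(do_slices)
cat_counts = op_counts(do_cat)

# Print ops for the slice + cat path, most frequently called first.
for name, count in sorted(cat_counts.items(), key=lambda kv: -kv[1]):
    print(name, count)
```

If an op shows up in `cat_counts` but not in `slice_counts`, that would suggest it is triggered by `cat` rather than by slicing.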