Hi there,
I am new to computational graphs in PyTorch and have been reading the blog post "How Computational Graphs are Executed in PyTorch". I have several questions about it.
The first question is about why threads are needed, from the section "GRAPH TRAVERSAL AND NODE EXECUTION". According to the article, the engine initializes one thread per device in addition to the main thread running the Python interpreter. Does this mean that threads exchange data with each other when computing gradients, that is, while executing a computational graph? Or, if there is no data exchange (say we have one CPU and one GPU but never move any data/tensors to the GPU), is the GPU's thread simply idle while the main thread on the CPU executes the computational graph(s)?
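One way to probe the threading question empirically is the minimal sketch below, assuming a recent PyTorch build. It records the Python thread id inside a tensor hook during `backward()` and compares it with the thread that called `backward()`. `backward_thread_id` is a hypothetical helper written for this illustration, not a PyTorch API.

```python
import threading

import torch


def backward_thread_id(device="cpu"):
    """Return the id of the thread that ran SinBackward0 for a tensor on `device`."""
    seen = []
    x = torch.randn(3, device=device, requires_grad=True)
    y = x.sin()
    # A hook on a non-leaf tensor fires when its grad_fn (SinBackward0) executes,
    # so it records whichever thread the engine used for that node.
    y.register_hook(lambda grad: seen.append(threading.get_ident()))
    y.sum().backward()
    return seen[0]


main_id = threading.get_ident()
print("CPU graph ran on calling thread:", backward_thread_id("cpu") == main_id)
if torch.cuda.is_available():
    # Assumption, not verified here: CUDA nodes are handled by a per-device worker thread.
    print("CUDA graph ran on calling thread:", backward_thread_id("cuda") == main_id)
```

On a CPU-only graph the two ids should match, which is consistent with the calling thread doing the CPU work itself. The CUDA branch only runs when a GPU is present; a different id there would fit the idea of a per-device worker, but the sketch alone does not prove how the engine is built.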
The second question is about the ReadyQueue. As far as I understand, each thread has a thread-local ready queue and starts executing a computational graph when the main thread enqueues a RootNode/task into that queue; the thread then computes the gradients as shown in Figure 5 (Animation of the execution of the computational graph). However, the blog also mentions that
"Also, if the rest of the graph is not on the cpu, we directly start on that worker while the RootNode is always placed on the cpu ready queue."
I just don't get how the CPU (or the main thread) knows what the RootNode is for a graph that is not on it. Or does the CPU build the graph, record its RootNode, and then move the graph to a GPU, or something like that?
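On where the graph itself lives, a small sketch (assuming a recent PyTorch) shows that the autograd graph is a chain of `Node` objects reachable from `grad_fn`, recorded as the forward ops run, and that its structure is the same whether the tensor data sits on the CPU or a GPU. `chain_names` is a hypothetical helper for this illustration.

```python
import torch

# Use the GPU if present; the recorded graph structure is the same either way.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(3, device=device, requires_grad=True)
loss = x.sin().sum()


def chain_names(node):
    """Follow the first input edge from `node` and collect each Node's name."""
    names = []
    while node is not None:
        names.append(node.name())
        # next_functions is a tuple of (Node, input_nr) pairs; it is empty at a leaf.
        node = node.next_functions[0][0] if node.next_functions else None
    return names


print(device, chain_names(loss.grad_fn))
```

The printed names should be the same on both devices; only the tensor data differs in location. This only walks the `grad_fn` chain, so it does not show the RootNode the blog mentions, which is handled inside the engine when `backward()` is called.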
The last question is about the order in which nodes are enqueued when a thread executes a computational graph. Take the graph in Figure 5 as an example, and assume there is no data exchange between threads (e.g., between two GPUs, or between a CPU and a GPU) while executing the graph. After thread_main executes the graph root, is the order of enqueuing its child nodes, LogBackward and SinBackward, guaranteed to be as shown in the animation, i.e., is LogBackward always enqueued before SinBackward? Or does it just depend on the actual implementation? And what would happen if there were data exchange between threads? Just curious.
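For the ordering question, a minimal sketch (assuming a recent PyTorch; the graph is a hypothetical stand-in, not necessarily identical to the one in Figure 5) uses tensor hooks to record the order in which LogBackward0 and SinBackward0 actually execute. Note that it observes execution order, not the order in which tasks are pushed onto the ready queue. `backward_order` is a hypothetical helper.

```python
import torch


def backward_order():
    """Return the order in which LogBackward0 and SinBackward0 ran during backward."""
    order = []
    # Positive values so that log() is well-defined; x is still a leaf tensor.
    x = (torch.rand(3) + 0.5).requires_grad_(True)
    a = x.log()  # created first
    b = x.sin()  # created second
    # Hooks on non-leaf tensors fire when their grad_fn node executes.
    a.register_hook(lambda grad: order.append("log"))
    b.register_hook(lambda grad: order.append("sin"))
    (a * b).sum().backward()
    return order


print(backward_order())
```

If the ready queue orders tasks by each node's sequence number, as comments in PyTorch's autograd source describe (nodes created later run first, all else being equal), SinBackward0 should run before LogBackward0 here even if LogBackward0 is pushed first. Swapping the two forward lines is a quick way to test that reading.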
Regards,