From TensorFlow to PyTorch (while_loop)

Hello everybody,

I am trying to rewrite a simulation code written in TensorFlow using PyTorch.
I am new to PyTorch and I am still learning to work with tensors in general.

I am stuck at rewriting tf.while_loop(), which, as far as I understand, is a special function in TensorFlow:

t_out, v_out, s_out, m_out, d_out, f_s_out = tf.while_loop(
            c,
            b,
            loop_vars=[t0, V, S, m0, d0, full_stim],
            shape_invariants=[t0.get_shape(),
                              tf.TensorShape([self.N*4, None]),
                              tf.TensorShape([self.N*4, self.syn_size]),
                              tf.TensorShape([self.N*4, 1]),
                              tf.TensorShape([self.N*4, 1]),
                              tf.TensorShape([self.N*4, None])])
        return v_out, s_out, m_out, d_out

What is the equivalent of this in PyTorch?

Another thing: how do I run a simulation? In TensorFlow it is done with this line:

v_sim, s_sim, m_sim, d_sim = tf.Session(config=self.config).run(
            [self.v_out, self.s_out, self.m_out, self.d_out],
            feed_dict={self.V:        self.v_state,
                       self.S:        self.s_state,
                       self.m0:       self.m_state,
                       self.d0:       self.d_state,
                       self.s_offset: 0,
                       self.stim:     stimulus})

How do I do something similar in PyTorch?

Thank you in advance for your help and I am sorry if these questions seem stupid.

You can use the Python while loop in PyTorch directly.
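
For instance, here is a minimal sketch of how the cond/body pair could map to plain Python (the state update is a made-up placeholder, since your c and b are not shown):

import torch

# Hypothetical stand-ins for your loop variables; shapes are illustrative.
V = torch.zeros(4, 10)
t, t_max, dt = 0.0, 1.0, 0.1

# tf.while_loop(c, b, loop_vars=...) becomes an ordinary Python loop:
# the condition c is a plain expression and the body b is regular
# PyTorch code reassigning the loop variables.
while t < t_max:           # condition c
    V = V + dt * (-V)      # body b: placeholder state update
    t = t + dt

# No shape_invariants are needed: tensors may change shape freely
# between iterations.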

PyTorch doesn’t use sessions and executes eagerly, i.e. if you run a PyTorch operation (in the default mode), the result is directly available (similar to e.g. NumPy).
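
As a rough sketch, the Session.run call from your snippet could collapse to a direct function call (simulate and the state tensors here are placeholders, since your actual simulation code isn’t shown):

import torch

# Placeholder for your simulation step; in eager mode there is no
# graph building, Session, or feed_dict.
def simulate(v_state, s_state, m_state, d_state, stimulus):
    # ... regular PyTorch ops implementing the simulation ...
    return v_state, s_state, m_state, d_state

v_state = torch.zeros(4, 10)   # illustrative shapes only
s_state = torch.zeros(4, 5)
m_state = torch.zeros(4, 1)
d_state = torch.zeros(4, 1)
stimulus = torch.zeros(4, 10)

# Just call the function; the results are immediately available.
v_sim, s_sim, m_sim, d_sim = simulate(v_state, s_state, m_state,
                                      d_state, stimulus)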

When you say “You can use the Python while loop in PyTorch directly.”, does the while loop run on the GPU or on the CPU? If it runs on the CPU, at each step the CPU may need to communicate with the GPU, which can take significant time.

The while loop would run on the CPU, but note that no communication or synchronization with the GPU is needed unless you use data-dependent control flow or otherwise access any CUDA tensor inside the loop.
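
To make the distinction concrete, a small sketch (assuming a CUDA device is available):

import torch

x = torch.randn(1024, device='cuda')

# No synchronization: the loop count is plain Python, so the CPU
# just enqueues the kernels and can run ahead of the GPU.
for _ in range(100):
    x = x * 0.5

# Synchronization: the condition depends on a CUDA tensor's value,
# so .item() forces the CPU to wait for the GPU in every iteration.
x = torch.randn(1024, device='cuda')
while x.abs().max().item() > 1e-3:
    x = x * 0.5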

For example, when you run autoregressive models for inference, like GPT-2 or T5, at each step you need to take the generated token and append it to the end of the input tensor. This communication has latency, and when you multiply it by the number of generated tokens, it may take significant time. Do you know if the TF while_loop runs on the GPU?

I might not understand the issue you are describing, since you mention appending a token to the input tensor (which seems to be a valid operation performed on the GPU, e.g. via torch.cat) but are concerned about the “while loop”. Operations inside the loop will be launched by the CPU on the corresponding device (e.g. the GPU, assuming all tensors are placed there), while the actual Python code (including the while statement) is executed on the CPU.
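
For example, a greedy decoding sketch where the token append itself stays on the GPU (model here is a dummy standing in for e.g. a GPT-2 forward pass):

import torch

def model(tokens):
    # Dummy forward pass returning next-token logits of shape
    # [batch, vocab]; a real language model would go here.
    return torch.randn(tokens.size(0), 50257, device=tokens.device)

tokens = torch.tensor([[0]], device='cuda')  # dummy prompt

for _ in range(100):                     # the loop itself runs on the CPU
    logits = model(tokens)               # GPU work, launched by the CPU
    next_token = logits.argmax(dim=-1, keepdim=True)  # stays on the GPU
    tokens = torch.cat([tokens, next_token], dim=1)   # GPU concatenation

# No CPU<->GPU copy happens inside the loop; the tokens only need to be
# moved to the CPU once at the end, e.g. via tokens.cpu().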

My question is whether the while-loop logic can be run on the GPU to save time. Let’s say the latency of communication between CPU and GPU is 1 ms and we output 100 tokens. The GPU runs the model and outputs logits. The CPU samples from the logits with some strategy, e.g. top-k and maybe also some kind of beam search, and decides on the output token for this step. Then it takes the token and sends it back to the GPU for the next timestep prediction. If we output 100 tokens, the communication alone takes 100 ms over the whole process. If this part of the logic ran on the GPU, we would save this communication time. I hope I explained myself correctly.

No, since the loop is pure Python code and is thus executed on the CPU. The CPU executes the Python code and launches CUDA kernels, which then execute the GPU workload.

This seems to be independent of the while loop logic and describes a mixture of GPU and CPU workloads. To allow the CPU to perform any operation on the output of the model (stored on the GPU), a synchronization will be added. However, I still don’t see the connection to the loop, as it sounds as if you are thinking about a) a “while loop kernel” or something similar which b) then launches CUDA kernels from the GPU itself?
Maybe related: if you want to avoid the CPU overhead added to each kernel launch, you could check if CUDA Graphs is compatible with your workload as described here.
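
A minimal inference sketch of the capture/replay pattern (assuming PyTorch 1.10+, a CUDA device, and static input shapes):

import torch

model = torch.nn.Linear(512, 512).cuda()
static_input = torch.randn(64, 512, device='cuda')

# Warmup on a side stream so setup work isn't recorded in the graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture the forward pass once...
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# ...then replay it for new data: copy into the captured input buffer
# and launch the entire recorded workload with a single CPU-side call.
static_input.copy_(torch.randn(64, 512, device='cuda'))
g.replay()
# static_output now holds the result for the new input.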
