From TensorFlow to PyTorch (while_loop)

Hello everybody,

I am trying to rewrite simulation code written in TensorFlow using PyTorch.
I am new to PyTorch and still learning to work with tensors in general.

I am stuck on rewriting tf.while_loop(), which, as far as I understand, is a special function in TensorFlow:

t_out, v_out, s_out, m_out, d_out, f_s_out = tf.while_loop(
    c,
    b,
    loop_vars=[t0, V, S, m0, d0, full_stim],
    shape_invariants=[t0.get_shape(),
                      tf.TensorShape([self.N*4, None]),
                      tf.TensorShape([self.N*4, self.syn_size]),
                      tf.TensorShape([self.N*4, 1]),
                      tf.TensorShape([self.N*4, 1]),
                      tf.TensorShape([self.N*4, None])])
return v_out, s_out, m_out, d_out

What is the equivalent of this in PyTorch?

Another question: how do I run a simulation? In TensorFlow it is done with this line:

v_sim, s_sim, m_sim, d_sim = tf.Session(config=self.config).run(
    [self.v_out, self.s_out, self.m_out, self.d_out],
    feed_dict={self.V:        self.v_state,
               self.S:        self.s_state,
               self.m0:       self.m_state,
               self.d0:       self.d_state,
               self.s_offset: 0,
               self.stim:     stimulus})

How do I do something similar in PyTorch?

Thank you in advance for your help and I am sorry if these questions seem stupid.

You can use a plain Python while loop in PyTorch directly.
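A minimal sketch of a small helper that keeps the c/b structure of the TensorFlow code, assuming the condition and body take the loop variables as positional arguments and the body returns them in the same order (as `tf.while_loop` expects). The helper name `while_loop`, the sizes, and the toy update rule are hypothetical stand-ins for the real simulation. Because PyTorch executes eagerly, a tensor may change shape between iterations, so nothing corresponds to `shape_invariants`.

```python
import torch


def while_loop(cond, body, loop_vars):
    """Eager stand-in for tf.while_loop: call body while cond is true."""
    loop_vars = tuple(loop_vars)
    while cond(*loop_vars):
        loop_vars = tuple(body(*loop_vars))
    return loop_vars


# Hypothetical sizes, only for illustration.
N, n_steps = 2, 5


def c(t, V, m):
    # A 0-dim bool tensor works as a Python condition (bool() is called on it).
    return t < n_steps


def b(t, V, m):
    # Hypothetical toy update: decay m and append it as a new column of V.
    m = 0.5 * m
    V = torch.cat([V, m], dim=1)  # V grows along dim 1; no shape invariant needed
    return t + 1, V, m


t0 = torch.tensor(0)
V0 = torch.zeros(N * 4, 1)
m0 = torch.ones(N * 4, 1)

t_out, V_out, m_out = while_loop(c, b, [t0, V0, m0])
print(V_out.shape)  # torch.Size([8, 6])
```

You could equally inline the while loop in your own method; the helper only preserves the separate condition/body functions so the translation stays close to the original.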

PyTorch doesn't use sessions; operations are executed eagerly. That is, if you run a PyTorch operation (in the default mode), the result is available immediately (similar to e.g. NumPy).
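A minimal sketch of how the `Session.run(...)` call might translate, assuming the graph placeholders (`self.V`, `self.S`, `self.m0`, `self.d0`, `self.stim`) simply become ordinary function arguments and the outputs become return values. The function `simulate`, its toy dynamics, the shapes, and `n_steps` are hypothetical. Wrapping the call in `torch.no_grad()` is optional; it only disables gradient tracking, which a pure forward simulation does not need.

```python
import torch


def simulate(v_state, s_state, m_state, d_state, stimulus, n_steps):
    # Hypothetical simulation: what used to be fed through feed_dict
    # is now passed in directly as tensors.
    v, s, m, d = v_state, s_state, m_state, d_state
    t = 0
    while t < n_steps:
        # Toy dynamics standing in for the real body function.
        v = v + 0.1 * (stimulus[:, t:t + 1] - v)
        m = 0.9 * m
        d = d + 1
        t += 1
    return v, s, m, d


N = 2
stimulus = torch.ones(N * 4, 10)

with torch.no_grad():  # no autograd graph needed for a forward-only simulation
    v_sim, s_sim, m_sim, d_sim = simulate(
        torch.zeros(N * 4, 1), torch.zeros(N * 4, 3),
        torch.ones(N * 4, 1), torch.zeros(N * 4, 1),
        stimulus, n_steps=10)

print(d_sim[0].item())  # 10.0
```

The results are regular tensors as soon as the call returns, so there is no separate "run" step; a value such as s_offset would just be one more function argument.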