Convert TensorFlow source code to PyTorch source code

I tried to convert the routing procedure from the TensorFlow source code of “Capsule Network-based Model for Learning Node Embeddings” to PyTorch, but I am not sure whether my conversion is correct.

The TensorFlow code:

```python
def routing(input, b_IJ, batch_size, iter_routing, num_caps_i, num_caps_j, len_u_i, len_v_j):
    # W: [num_caps_i, num_caps_j, len_u_i, len_v_j]
    W = tf.get_variable('Weight', shape=(1, num_caps_i, num_caps_j, len_u_i, len_v_j), dtype=tf.float32,
                        initializer=tf.random_normal_initializer(stddev=0.01, seed=1234))

    # Eq.2, calc u_hat
    # do tiling for input and W before matmul
    input = tf.tile(input, [1, 1, num_caps_j, 1, 1])
    W = tf.tile(W, [batch_size, 1, 1, 1, 1])

    # in last 2 dims:
    u_hat = tf.matmul(W, input, transpose_a=True)
    # In forward, u_hat_stopped = u_hat; in backward, no gradient passed back from u_hat_stopped to u_hat
    u_hat_stopped = tf.stop_gradient(u_hat, name='stop_gradient')

    # line 3,for r iterations do
    for r_iter in range(iter_routing):
        with tf.variable_scope('iter_' + str(r_iter)):
            # line 4:
            c_IJ = tf.nn.softmax(b_IJ, axis=1) * num_caps_i #axis=1 # original code

            if r_iter == iter_routing - 1:
                # line 5:
                # weighting u_hat with c_IJ, element-wise in the last two dims
                s_J = tf.multiply(c_IJ, u_hat)
                # then sum in the second dim, resulting in [batch_size, 1, num_caps_j, len_v_j, 1]
                s_J = tf.reduce_sum(s_J, axis=1, keepdims=True)
                # line 6:
                # squash using Eq.1,
                v_J = squash(s_J)

            elif r_iter < iter_routing - 1:  # Inner iterations, do not apply backpropagation
                s_J = tf.multiply(c_IJ, u_hat_stopped)
                s_J = tf.reduce_sum(s_J, axis=1, keepdims=True)
                v_J = squash(s_J)
                # line 7:
                v_J_tiled = tf.tile(v_J, [1, num_caps_i, 1, 1, 1])
                b_IJ = tf.matmul(u_hat_stopped, v_J_tiled, transpose_a=True)
```
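For reference, the tiling, the `transpose_a=True` matmul and `tf.stop_gradient` in the TensorFlow code have direct PyTorch counterparts. The sketch below is only an illustration: the sizes and random tensors are made up (they are not values from the paper's code), and the tensor layouts are taken from the shape comments in the TensorFlow snippet. It checks that a literal `tf.tile` translation (`Tensor.repeat`) produces the same `u_hat` as plain broadcasting with `torch.matmul`.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch_size, num_caps_i, num_caps_j, len_u_i, len_v_j = 2, 5, 3, 4, 6

# Same layouts as the TensorFlow code:
#   W:     [1, num_caps_i, num_caps_j, len_u_i, len_v_j], normal init with std 0.01
#   input: [batch_size, num_caps_i, 1, len_u_i, 1]
W = (torch.randn(1, num_caps_i, num_caps_j, len_u_i, len_v_j) * 0.01).requires_grad_()
x = torch.randn(batch_size, num_caps_i, 1, len_u_i, 1)

# Literal translation: tf.tile -> Tensor.repeat, transpose_a=True -> swap the last two dims.
x_tiled = x.repeat(1, 1, num_caps_j, 1, 1)
W_tiled = W.repeat(batch_size, 1, 1, 1, 1)
u_hat_tiled = torch.matmul(W_tiled.transpose(-1, -2), x_tiled)

# torch.matmul broadcasts the leading (batch) dims, so the tiling can be skipped.
u_hat = torch.matmul(W.transpose(-1, -2), x)

# tf.stop_gradient(u_hat) -> u_hat.detach(): same values, but no gradient flows back to W.
u_hat_stopped = u_hat.detach()

print(u_hat.shape)  # torch.Size([2, 5, 3, 6, 1])
```

Inside an `nn.Module`, `W` would normally be registered as an `nn.Parameter` (for example initialized with `nn.init.normal_(..., std=0.01)`), so the explicit copies made with `torch.stack`/`torch.cat` are not strictly required.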


My PyTorch conversion:

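```python
import torch
import torch.nn.functional as F


def squash(s):
    """
    Squash activations.

    :param s: Signal.
    :return s: Activated signal.
    """
    mag_sq = torch.sum(s**2, dim=2, keepdim=True)
    mag = torch.sqrt(mag_sq)
    s = (mag_sq / (1.0 + mag_sq)) * (s / mag)
    return s


# routing() belongs to the capsule layer class (not shown), whose __init__
# defines self.W, self.in_channels and self.num_units.
def routing(self, x, iter_routing):
    """
    :param x: Input features.
    :return: Capsule output.
    """
    batch_size = x.size(0)

    # Swap dims 1 and 2 so that the in_channels (input capsule) axis is dim 1.
    x = x.transpose(1, 2)
    # Copy x once per output capsule along a new dim 2, then add a trailing
    # dim so that matmul treats the last two dims as a column vector.
    x = torch.stack([x] * self.num_units, dim=2).unsqueeze(4)

    # Repeat the shared weight tensor for every sample in the batch.
    W = torch.cat([self.W] * batch_size, dim=0)

    # Transform the inputs by the weight matrix.
    u_hat = torch.matmul(W, x)

    # Routing logits, shared across the batch: [1, in_channels, num_units, 1].
    b_ij = torch.zeros(1, self.in_channels, self.num_units, 1, device=x.device)

    for _ in range(iter_routing):
        c_ij = F.softmax(b_ij, dim=2)
        # Copy the coupling coefficients for every sample; add a trailing dim to match u_hat.
        c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

        s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
        v_j = squash(s_j)

        # Agreement between u_hat and v_j, averaged over the batch.
        v_j1 = torch.cat([v_j] * self.in_channels, dim=1)
        u_vj1 = torch.matmul(u_hat.transpose(3, 4), v_j1).squeeze(4).mean(dim=0, keepdim=True)

        b_ij = b_ij + u_vj1

        # b_max = torch.max(b_ij, dim=2, keepdim=True)
        # b_ij = b_ij / b_max.values  # values can be zero, so the loss would be NaN

    return v_j.squeeze(1)
```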
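For comparison, here is a sketch that follows the TensorFlow `routing` loop line by line. It rests on several assumptions, since parts of the original are not shown: `u_hat` is taken as already computed by Eq. 2 with shape `[batch_size, num_caps_i, num_caps_j, len_v_j, 1]`; `b_IJ` is assumed to have shape `[batch_size, num_caps_i, num_caps_j, 1, 1]` (the caller is not shown); the TensorFlow `squash` (also not shown) is assumed to normalize over the `len_v_j` axis, which is the axis that holds the capsule vector according to the shape comment on `s_J`; and returning `v_J` is an assumption, because the snippet ends without a `return`. The small `eps` in `squash` is an added assumption to avoid dividing by zero.

```python
import torch
import torch.nn.functional as F


def squash(s, dim=-2, eps=1e-9):
    # Eq. 1: rescale each capsule vector so that its length becomes |s|^2 / (1 + |s|^2).
    # dim=-2 is the len_v_j axis of a [batch, 1, num_caps_j, len_v_j, 1] tensor.
    # eps is an assumption, added only to avoid division by zero.
    mag_sq = (s ** 2).sum(dim=dim, keepdim=True)
    return mag_sq / (1.0 + mag_sq) * s / torch.sqrt(mag_sq + eps)


def routing(u_hat, b_IJ, iter_routing):
    # u_hat: [batch, num_caps_i, num_caps_j, len_v_j, 1]
    # b_IJ:  [batch, num_caps_i, num_caps_j, 1, 1], one set of logits per sample (assumed)
    num_caps_i = u_hat.size(1)
    u_hat_stopped = u_hat.detach()  # tf.stop_gradient
    for r_iter in range(iter_routing):
        # Line 4: softmax over the input-capsule axis (axis=1), scaled by num_caps_i.
        c_IJ = F.softmax(b_IJ, dim=1) * num_caps_i
        if r_iter == iter_routing - 1:
            # Lines 5-6, last iteration: gradients flow back through u_hat.
            s_J = (c_IJ * u_hat).sum(dim=1, keepdim=True)
            v_J = squash(s_J)
        else:
            # Inner iterations: use the detached u_hat, so no backpropagation.
            s_J = (c_IJ * u_hat_stopped).sum(dim=1, keepdim=True)
            v_J = squash(s_J)
            # Line 7: b_IJ is replaced (not accumulated) by u_hat^T v_J.
            # Broadcasting over dim 1 replaces tf.tile(v_J, [1, num_caps_i, 1, 1, 1]).
            b_IJ = torch.matmul(u_hat_stopped.transpose(-1, -2), v_J)
    return v_J  # [batch, 1, num_caps_j, len_v_j, 1]
```

Comparing this sketch with the PyTorch version above, the points that appear to differ from the TensorFlow code are: the softmax is taken over `dim=2` (output capsules) instead of axis 1 and is not multiplied by `num_caps_i`; `b_ij` is accumulated and averaged over the batch instead of being replaced per sample; the inner iterations do not detach `u_hat`; and `squash` sums over `dim=2`, which is the `num_units` axis of `s_j` rather than the vector-length axis.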

Thank you.

Converting TensorFlow code to PyTorch is a common challenge in the machine learning community. Although this thread is old, I want to mention a tool that may help others who run into the same problem. Disclosure: I work with Unify, the team behind Ivy. Ivy is an open-source tool that takes a unified approach to working with different frameworks, which may simplify this kind of conversion. See the Ivy documentation for more information.