Convert TensorFlow source code to PyTorch

I tried to convert the routing procedure of "Capsule Network-based Model for Learning Node Embeddings" from TensorFlow to PyTorch, but I am not sure whether my conversion is correct. In particular, I am unsure about the softmax axis, the dimension used in squash, the stop-gradient on the inner iterations, and the b_IJ update (the TF code overwrites b_IJ, while my PyTorch version accumulates into it and averages over the batch).
The TensorFlow code:

import tensorflow as tf


def routing(input, b_IJ, batch_size, iter_routing, num_caps_i, num_caps_j, len_u_i, len_v_j):
    # W: [num_caps_i, num_caps_j, len_u_i, len_v_j]
    W = tf.get_variable('Weight', shape=(1, num_caps_i, num_caps_j, len_u_i, len_v_j), dtype=tf.float32,
                        initializer=tf.random_normal_initializer(stddev=0.01, seed=1234))

    # Eq.2, calc u_hat
    # do tiling for input and W before matmul
    input = tf.tile(input, [1, 1, num_caps_j, 1, 1])
    W = tf.tile(W, [batch_size, 1, 1, 1, 1])

    # matmul in the last two dims: [len_u_i, len_v_j].T x [len_u_i, 1] -> [len_v_j, 1]
    u_hat = tf.matmul(W, input, transpose_a=True)
    # In forward, u_hat_stopped = u_hat; in backward, no gradient passed back from u_hat_stopped to u_hat
    u_hat_stopped = tf.stop_gradient(u_hat, name='stop_gradient')

    # line 3,for r iterations do
    for r_iter in range(iter_routing):
        with tf.variable_scope('iter_' + str(r_iter)):
            # line 4:
            c_IJ = tf.nn.softmax(b_IJ, axis=1) * num_caps_i  # softmax over axis 1 (num_caps_i), as in the original code

            if r_iter == iter_routing - 1:
                # line 5:
                # weighting u_hat with c_IJ, element-wise in the last two dims
                s_J = tf.multiply(c_IJ, u_hat)
                # then sum in the second dim, resulting in [batch_size, 1, num_caps_j, len_v_j, 1]
                s_J = tf.reduce_sum(s_J, axis=1, keepdims=True)
                # line 6:
                # squash using Eq.1,
                v_J = squash(s_J)

            elif r_iter < iter_routing - 1:  # Inner iterations, do not apply backpropagation
                s_J = tf.multiply(c_IJ, u_hat_stopped)
                s_J = tf.reduce_sum(s_J, axis=1, keepdims=True)
                v_J = squash(s_J)
                # line 7:
                v_J_tiled = tf.tile(v_J, [1, num_caps_i, 1, 1, 1])
                b_IJ = tf.matmul(u_hat_stopped, v_J_tiled, transpose_a=True)

    return v_J
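
For reference, the TF snippet calls squash but does not define it. Here is a minimal sketch of the squash nonlinearity (Eq. 1), based on common CapsNet TensorFlow implementations; the epsilon for numerical stability is my own addition:

def squash(vector, epsilon=1e-9):
    # squared norm along the capsule's vector dimension (second-to-last)
    vec_squared_norm = tf.reduce_sum(tf.square(vector), -2, keepdims=True)
    scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + epsilon)
    return scalar_factor * vector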

My attempt in PyTorch:

import torch
from torch.autograd import Variable


def squash(s):
    """
    Squash activations.

    :param s: Signal.
    :return s: Activated signal.
    """
    mag_sq = torch.sum(s ** 2, dim=2, keepdim=True)
    mag = torch.sqrt(mag_sq)
    s = (mag_sq / (1.0 + mag_sq)) * (s / mag)
    return s
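
As a quick self-check of squash (my own illustration, not part of the model), the output's norm along dim=2 should stay strictly below 1 while the direction is preserved:

s = torch.randn(4, 1, 16, 8, 1)
v = squash(s)
print(v.norm(dim=2).max())  # always < 1.0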

# Method of the capsule layer; self.W, self.num_units and self.in_channels
# are defined in the module's __init__.
def routing(self, x, iter_routing):
    """
    :param x: Input features.
    :return: Capsule output.
    """
    batch_size = x.size(0)
    x = x.transpose(1, 2)
    # replicate x once per output capsule and add a trailing dim for matmul
    x = torch.stack([x] * self.num_units, dim=2).unsqueeze(4)
    # replicate the shared weight across the batch
    W = torch.cat([self.W] * batch_size, dim=0)
    u_hat = torch.matmul(W, x)
    b_ij = Variable(torch.zeros(1, self.in_channels, self.num_units, 1))

    for _ in range(iter_routing):
        c_ij = torch.nn.functional.softmax(b_ij, dim=2)
        c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)
        # weighted sum over the input capsules
        s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
        v_j = squash(s_j)
        v_j1 = torch.cat([v_j] * self.in_channels, dim=1)
        # agreement between the predictions and the current output,
        # averaged over the batch
        u_vj1 = torch.matmul(u_hat.transpose(3, 4), v_j1).squeeze(4).mean(dim=0, keepdim=True)
        b_ij = b_ij + u_vj1
        # b_max = torch.max(b_ij, dim=2, keepdim=True)
        # b_ij = b_ij / b_max.values  # values can be zero, so the loss would be NaN

    return v_j.squeeze(1)
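
While comparing the two versions, I also sketched a more literal, line-by-line port of the TF routine (unverified; squash_ is my guess at the TF squash, applied along the vector dimension, and the shapes in the comments are assumptions based on the TF code):

import torch
import torch.nn.functional as F

def squash_(s, dim=-2, epsilon=1e-9):
    # assumed TF squash: squared norm over the capsule's vector dimension
    sq = (s ** 2).sum(dim=dim, keepdim=True)
    return sq / (1.0 + sq) * s / torch.sqrt(sq + epsilon)

def routing_literal(input, b_ij, iter_routing, num_caps_i, num_caps_j, W):
    # input: [batch, num_caps_i, 1, len_u_i, 1], W: [1, num_caps_i, num_caps_j, len_u_i, len_v_j]
    batch_size = input.size(0)
    input = input.repeat(1, 1, num_caps_j, 1, 1)      # tf.tile(input, ...)
    W = W.repeat(batch_size, 1, 1, 1, 1)              # tf.tile(W, ...)
    u_hat = torch.matmul(W.transpose(-1, -2), input)  # tf.matmul(..., transpose_a=True)
    u_hat_stopped = u_hat.detach()                    # tf.stop_gradient

    for r_iter in range(iter_routing):
        c_ij = F.softmax(b_ij, dim=1) * num_caps_i    # softmax over num_caps_i, as in the TF code
        if r_iter == iter_routing - 1:
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = squash_(s_j)
        else:
            s_j = (c_ij * u_hat_stopped).sum(dim=1, keepdim=True)
            v_j = squash_(s_j)
            v_j_tiled = v_j.repeat(1, num_caps_i, 1, 1, 1)
            # note: the TF code overwrites b_IJ here rather than accumulating into it
            b_ij = torch.matmul(u_hat_stopped.transpose(-1, -2), v_j_tiled)
    return v_j

Is this closer to what the TF code does, or is my class-based version above already equivalent?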

Thank you in advance for any help.