I tried to convert a routing process source code for “Capsule Network-based Model for Learning Node Embeddings” written in TensorFlow to PyTorch. But I don’t know if the conversion is correct?
the code in TensorFlow:
def routing(input, b_IJ, batch_size, iter_routing, num_caps_i, num_caps_j, len_u_i, len_v_j):
# W: [num_caps_i, num_caps_j, len_u_i, len_v_j]
W = tf.get_variable('Weight', shape=(1, num_caps_i, num_caps_j, len_u_i, len_v_j), dtype=tf.float32,
initializer=tf.random_normal_initializer(stddev=0.01, seed=1234))
# Eq.2, calc u_hat
# do tiling for input and W before matmul
input = tf.tile(input, [1, 1, num_caps_j, 1, 1])
W = tf.tile(W, [batch_size, 1, 1, 1, 1])
# in last 2 dims:
u_hat = tf.matmul(W, input, transpose_a=True)
# In forward, u_hat_stopped = u_hat; in backward, no gradient passed back from u_hat_stopped to u_hat
u_hat_stopped = tf.stop_gradient(u_hat, name='stop_gradient')
# line 3,for r iterations do
for r_iter in range(iter_routing):
with tf.variable_scope('iter_' + str(r_iter)):
# line 4:
c_IJ = tf.nn.softmax(b_IJ, axis=1) * num_caps_i #axis=1 # original code
if r_iter == iter_routing - 1:
# line 5:
# weighting u_hat with c_IJ, element-wise in the last two dims
s_J = tf.multiply(c_IJ, u_hat)
# then sum in the second dim, resulting in [batch_size, 1, num_caps_j, len_v_j, 1]
s_J = tf.reduce_sum(s_J, axis=1, keepdims=True)
# line 6:
# squash using Eq.1,
v_J = squash(s_J)
elif r_iter < iter_routing - 1: # Inner iterations, do not apply backpropagation
s_J = tf.multiply(c_IJ, u_hat_stopped)
s_J = tf.reduce_sum(s_J, axis=1, keepdims=True)
v_J = squash(s_J)
# line 7:
v_J_tiled = tf.tile(v_J, [1, num_caps_i, 1, 1, 1])
b_IJ = tf.matmul(u_hat_stopped, v_J_tiled, transpose_a=True)
return(v_J)
the result that I made in PyTorch:
def squash(s):
"""
Squash activations.
:param s: Signal.
:return s: Activated signal.
"""
mag_sq = torch.sum(s**2, dim=2, keepdim=True)
mag = torch.sqrt(mag_sq)
s = (mag_sq / (1.0 + mag_sq)) * (s / mag)
return s
def routing(self, x, iter_routing):
"""
:param x: Input features.
:return : Capsule output.
"""
batch_size = x.size(0)
x = x.transpose(1, 2)
x = torch.stack([x] * self.num_units, dim=2).unsqueeze(4)
W = torch.cat([self.W] * batch_size, dim=0)
u_hat = torch.matmul(W, x)
b_ij = Variable(torch.zeros(1, self.in_channels, self.num_units, 1))
for _ in range(iter_routing):
c_ij = torch.nn.functional.softmax(b_ij, dim=2)
c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)
s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
v_j = squash(s_j)
v_j1 = torch.cat([v_j] * self.in_channels, dim=1)
u_vj1 = torch.matmul(u_hat.transpose(3, 4), v_j1).squeeze(4).mean(dim=0, keepdim=True)
b_ij = b_ij + u_vj1
# b_max = torch.max(b_ij, dim = 2, keepdim = True)
# b_ij = b_ij / b_max.values ## values can be zero so loss would be nan
return v_j.squeeze(1)
and thank you