After migrating from Keras, it runs slower

Hi,

I have a piece of code for a CapsuleNet implemented in Keras as follows:

from keras import backend as K
from keras.layers import (Layer, Activation, Input, Dense, Embedding, SpatialDropout1D,
                          Bidirectional, CuDNNLSTM, Flatten, concatenate)
from keras.models import Model


def squash(x, axis=-1):
    # s_squared_norm is really small
    # s_squared_norm = K.sum(K.square(x), axis, keepdims=True) + K.epsilon()
    # scale = K.sqrt(s_squared_norm)/ (0.5 + s_squared_norm)
    # return scale * x
    s_squared_norm = K.sum(K.square(x), axis, keepdims=True)
    scale = K.sqrt(s_squared_norm + K.epsilon())
    return x / scale


# A Capsule implementation in pure Keras
class Capsule(Layer):
    def __init__(self, num_capsule, dim_capsule, routings=3, kernel_size=(9, 1), share_weights=True,
                 activation='default', **kwargs):
        super(Capsule, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.kernel_size = kernel_size
        self.share_weights = share_weights
        if activation == 'default':
            self.activation = squash
        else:
            self.activation = Activation(activation)

    def build(self, input_shape):
        super(Capsule, self).build(input_shape)
        input_dim_capsule = input_shape[-1]
        if self.share_weights:
            self.W = self.add_weight(name='capsule_kernel',
                                     shape=(1, input_dim_capsule,
                                            self.num_capsule * self.dim_capsule),
                                     # shape=self.kernel_size,
                                     initializer='glorot_uniform',
                                     trainable=True)
        else:
            input_num_capsule = input_shape[-2]
            self.W = self.add_weight(name='capsule_kernel',
                                     shape=(input_num_capsule,
                                            input_dim_capsule,
                                            self.num_capsule * self.dim_capsule),
                                     initializer='glorot_uniform',
                                     trainable=True)

    def call(self, u_vecs):
        if self.share_weights:
            u_hat_vecs = K.conv1d(u_vecs, self.W)
        else:
            u_hat_vecs = K.local_conv1d(u_vecs, self.W, [1], [1])

        batch_size = K.shape(u_vecs)[0]
        input_num_capsule = K.shape(u_vecs)[1]
        u_hat_vecs = K.reshape(u_hat_vecs, (batch_size, input_num_capsule,
                                            self.num_capsule, self.dim_capsule))
        u_hat_vecs = K.permute_dimensions(u_hat_vecs, (0, 2, 1, 3))
        # final u_hat_vecs.shape = [None, num_capsule, input_num_capsule, dim_capsule]

        b = K.zeros_like(u_hat_vecs[:, :, :, 0])  # shape = [None, num_capsule, input_num_capsule]
        for i in range(self.routings):
            b = K.permute_dimensions(b, (0, 2, 1))  # shape = [None, input_num_capsule, num_capsule]
            c = K.softmax(b)
            c = K.permute_dimensions(c, (0, 2, 1))
            b = K.permute_dimensions(b, (0, 2, 1))
            outputs = self.activation(K.batch_dot(c, u_hat_vecs, [2, 2]))
            if i < self.routings - 1:
                b = K.batch_dot(outputs, u_hat_vecs, [2, 3])

        return outputs

    def compute_output_shape(self, input_shape):
        return (None, self.num_capsule, self.dim_capsule)


def get_capsule(embedding_matrix):
    features_input = Input(shape=(features.shape[1],))
    z = Dense(4, activation="relu")(features_input)
    inp = Input(shape=(maxlen,))

    x = Embedding(max_features, embedding_matrix.shape[1], weights=[embedding_matrix], trainable=False)(inp)
    x = SpatialDropout1D(rate=0.2)(x)
    x = Bidirectional(CuDNNLSTM(90, return_sequences=True))(x)
    x = Capsule(num_capsule=8, dim_capsule=8, routings=4, share_weights=True)(x)
    x = Flatten()(x)
    x = Dense(32, activation="relu")(x)
    x = concatenate([z, x])
    outp = Dense(1, activation="sigmoid")(x)
    model = Model(inputs=[inp, features_input], outputs=outp)

    # opt = RMSprop(clipnorm=0.03)  # clipnorm=0.03
    model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
    return model
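
For context, it is trained roughly like this (a sketch, not the exact training code; the variable names, batch size, and epoch count are placeholders):

model = get_capsule(embedding_matrix)
# two inputs: the padded token ids and the hand-crafted features
model.fit([x_train, features_train], y_train,
          batch_size=256, epochs=3,
          validation_data=([x_val, features_val], y_val))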

I rewrote it in PyTorch:

import torch
import torch.nn as nn


class GRU_Layer(nn.Module):
    def __init__(self):
        super(GRU_Layer, self).__init__()
        # despite the name this is an LSTM, to match the Keras CuDNNLSTM;
        # batch_first=True because the embedding output is (batch, seq, features)
        self.gru = nn.LSTM(input_size=300,
                           hidden_size=gru_len,
                           bidirectional=True,
                           batch_first=True)

    def init_weights(self):
        ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
        hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
        b = (param.data for name, param in self.named_parameters() if 'bias' in name)
        for k in ih:
            nn.init.xavier_uniform_(k)
        for k in hh:
            nn.init.orthogonal_(k)
        for k in b:
            nn.init.constant_(k, 0)

    def forward(self, x):
        return self.gru(x)


class Caps_Layer(nn.Module):
    def __init__(self, input_dim_capsule=gru_len * 2, num_capsule=2, dim_capsule=2,
                 routings=2, kernel_size=(9, 1), share_weights=True,
                 activation='default', **kwargs):
        super(Caps_Layer, self).__init__(**kwargs)

        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.kernel_size = kernel_size
        self.share_weights = share_weights
        if activation == 'default':
            self.activation = self.squash
        else:
            self.activation = nn.ReLU(inplace=True)

        if self.share_weights:
            self.W = nn.Parameter(
                nn.init.xavier_uniform_(torch.empty(self.num_capsule * self.dim_capsule, input_dim_capsule, 1)))
        else:
            # non-shared weights: the first dimension plays the role of input_num_capsule (hard-coded here)
            self.W = nn.Parameter(
                torch.randn(1024, input_dim_capsule, self.num_capsule * self.dim_capsule))

    def forward(self, x):

        if self.share_weights:
            # F.conv1d returns (batch, num_capsule * dim_capsule, input_num_capsule);
            # Keras' K.conv1d is channels-last, so move the channels to the end
            # before reshaping, otherwise capsules and positions get mixed up
            u_hat_vecs = torch.nn.functional.conv1d(x.permute(0, 2, 1), self.W)
            u_hat_vecs = u_hat_vecs.permute(0, 2, 1).contiguous()
        else:
            raise NotImplementedError('non-shared weights are not implemented yet')

        batch_size = u_hat_vecs.size(0)
        input_num_capsule = u_hat_vecs.size(1)
        u_hat_vecs = u_hat_vecs.view((batch_size, input_num_capsule,
                                      self.num_capsule, self.dim_capsule))
        u_hat_vecs = u_hat_vecs.permute(0, 2, 1, 3)
        b = torch.zeros_like(u_hat_vecs[:, :, :, 0])

        for i in range(self.routings):
            b = b.permute(0, 2, 1)
            c = torch.nn.functional.softmax(b, dim=2)
            c = c.permute(0, 2, 1)
            b = b.permute(0, 2, 1)
            outputs = self.activation(torch.einsum('bij,bijk->bik', (c, u_hat_vecs)))
            if i < self.routings - 1:
                b = torch.einsum('bik,bijk->bij', (outputs, u_hat_vecs))
        return outputs

    def squash(self, x, axis=-1):
        s_squared_norm = (x ** 2).sum(axis, keepdim=True)
        scale = torch.sqrt(s_squared_norm + T_epsilon)
        return x / scale


class Capsule_Main(nn.Module):
    def __init__(self):
        super(Capsule_Main, self).__init__()
        self.name = 'Capsule_Main'

        self.embedding = nn.Embedding(max_features, embed_size)
        self.embedding.weight = nn.Parameter(torch.tensor(embedding_matrix, dtype=torch.float32))
        self.embedding.weight.requires_grad = False

        self.gru_layer = GRU_Layer()
        self.gru_layer.init_weights()
        self.caps_layer = Caps_Layer(num_capsule=8, dim_capsule=8, routings=4)
        self.sp = SpatialDropout1D(p=0.2)  # custom spatial dropout module (definition not shown)

        self.fc_fe = nn.Sequential(
            nn.Linear(7, 4),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Sequential(
            nn.Linear(64 + 4, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, content, x_fe):
        content1 = self.embedding(content)
        content1 = self.sp(content1)
        content2, (h, c) = self.gru_layer(content1)
        content3 = self.caps_layer(content2)
        content3 = content3.view(content3.size()[0], -1)

        fe_out = self.fc_fe(x_fe)

        conc = torch.cat((content3, fe_out), 1)

        output = self.fc(conc)
        return output

I found that, with the same hyperparameters, the PyTorch version runs about 1.5 times slower than the Keras version, and its loss is also about 0.02 worse.
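
For reference, the comparison is timed with something along these lines (a minimal sketch, not the exact harness; the batch size, sequence length, and iteration count are placeholders, and a CUDA device is assumed):

import time
import torch

model = Capsule_Main().cuda()

# placeholder inputs: 256 padded sequences of length 70 plus 7 hand-crafted features
content = torch.randint(0, max_features, (256, 70)).cuda()
x_fe = torch.randn(256, 7).cuda()

with torch.no_grad():
    torch.cuda.synchronize()   # flush pending kernels before starting the clock
    start = time.time()
    for _ in range(100):
        _ = model(content, x_fe)
    torch.cuda.synchronize()   # wait for all forward passes to actually finish
print('100 forward passes: %.3f s' % (time.time() - start))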

What is wrong with it?

Thanks!

Using torch.norm in squash will be more efficient than spelling it out, as you save (slow) memory stores/reads.
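
Something like this (a minimal sketch; the epsilon handling differs slightly from yours, so it is not bit-for-bit identical):

    def squash(self, x, axis=-1):
        # torch.norm fuses square, sum, and sqrt into a single reduction,
        # so the (x ** 2) intermediate is never written to / read from memory
        norm = torch.norm(x, p=2, dim=axis, keepdim=True)
        return x / (norm + T_epsilon)  # same epsilon constant as in your code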

Best regards

Thomas