Custom layer from Keras to PyTorch

Coming from a TensorFlow background, I am trying to convert a snippet of code for a custom layer from Keras to PyTorch.

The custom layer in Keras looks like this:

import tensorflow as tf
from tensorflow.keras import backend as K

class Attention_module(tf.keras.layers.Layer):
    def __init__(self, class_num):
        super(Attention_module, self).__init__()
        self.class_num = class_num
        self.Ws = None

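    def build(self, input_shape):
        # Runs once, when the input shape is first known, so Ws can be
        # sized from the embedding length (the last input dimension).
        embedding_length = int(input_shape[2])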

        self.Ws = self.add_weight(shape=(self.class_num, embedding_length),
                                  initializer=tf.keras.initializers.get('glorot_uniform'), trainable=True)

        super(Attention_module, self).build(input_shape)

    def call(self, inputs):

        sentence_trans = tf.transpose(inputs, [0, 2, 1])
        at = tf.matmul(self.Ws, sentence_trans)
        at = tf.math.tanh(at)
        at = K.exp(at - K.max(at, axis=-1, keepdims=True))
        at = at / K.sum(at, axis=-1, keepdims=True)
        v = K.batch_dot(at, inputs)

        return v
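To make the shapes explicit, here is a minimal NumPy sketch of what `call` computes, assuming `inputs` has shape `(batch, seq_len, embedding_length)`, as implied by `build` reading `input_shape[2]`. The helper name `attention_reference` is hypothetical, and `Ws` is just a random array here, not a trained weight.

```python
import numpy as np


def attention_reference(inputs, Ws):
    """Hypothetical NumPy restatement of Attention_module.call.

    inputs: (batch, seq_len, emb); Ws: (class_num, emb).
    Returns v: (batch, class_num, emb) and the attention weights.
    """
    sentence_trans = np.transpose(inputs, (0, 2, 1))  # (batch, emb, seq_len)
    at = np.matmul(Ws, sentence_trans)                # Ws broadcast -> (batch, class_num, seq_len)
    at = np.tanh(at)
    # Numerically stable softmax over the last axis (seq_len)
    at = np.exp(at - at.max(axis=-1, keepdims=True))
    at = at / at.sum(axis=-1, keepdims=True)
    v = np.matmul(at, inputs)                         # (batch, class_num, emb)
    return v, at


rng = np.random.default_rng(0)
inputs = rng.standard_normal((2, 5, 8))  # batch=2, seq_len=5, emb=8
Ws = rng.standard_normal((3, 8))         # class_num=3
v, at = attention_reference(inputs, Ws)
print(v.shape)  # (2, 3, 8)
```

The last `matmul` contracts the last axis of `at` with the second axis of `inputs`, which is what `K.batch_dot` does here for two 3-D tensors with its default axes.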

I want to implement the same layer in PyTorch. I have already written the forward pass, but I am confused about how to handle the embedding length and initialize the weight the same way as the Keras layer above.

import torch

class Attention_module(torch.nn.Module):
    def __init__(self, class_num):

        # how to initialize weight with same as above keras layer?

    def forward(self, inputs):
        sentence_trans = inputs.permute(0, 2, 1)

        at = torch.matmul(self.Ws, sentence_trans)
        at = torch.nn.Tanh(at)
        at = torch.exp(at - torch.max(torch.Tensor(at), dim=-1, keepdims=True).values)
        at = at / torch.sum(at, dim = -1, keepdims=True)
        v = torch.einsum('ijk,ikl->ijl', at, inputs)

        return v

Thank you!

There are several aspects:

  • the parameter is created by assigning an nn.Parameter of the appropriate shape to self.Ws in __init__ (after super().__init__()).
  • the equivalent initialization is torch.nn.init.xavier_uniform_, but you may have to adjust the gain parameter. Check the mean and std (and a histogram plot)
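  • you have some funny stuff in the forward, mixing torch.nn classes (e.g. torch.nn.Tanh, which is a module, not a function) where I would have expected torch.nn.functional (or plain torch) functions such as torch.tanh. Also, you should never need to use torch.Tensor there.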
  • finally, are you spelling out torch.nn.functional.softmax in the line with torch.exp and the one after it? If you are, I’d probably just write softmax. I’m not entirely sure why you would do tanh and then softmax, but hey, what do I know of your problem.
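Putting the points above together, a minimal sketch of the module could look like this. It assumes the embedding length is passed to `__init__` explicitly (the `embedding_length` argument is an addition, not part of the original layer), because a plain `nn.Module` has no `build` hook that sees the input shape.

```python
import torch
import torch.nn as nn


class Attention_module(nn.Module):
    def __init__(self, class_num, embedding_length):  # embedding_length: assumed extra argument
        super().__init__()
        self.class_num = class_num
        # Counterpart of Keras add_weight: an nn.Parameter assigned as an
        # attribute is registered automatically and returned by .parameters().
        self.Ws = nn.Parameter(torch.empty(class_num, embedding_length))
        # Keras 'glorot_uniform' corresponds to Xavier uniform initialization.
        nn.init.xavier_uniform_(self.Ws)

    def forward(self, inputs):
        # inputs: (batch, seq_len, embedding_length)
        sentence_trans = inputs.permute(0, 2, 1)    # (batch, emb, seq_len)
        at = torch.matmul(self.Ws, sentence_trans)  # (batch, class_num, seq_len)
        at = torch.tanh(at)                         # function, not the nn.Tanh module
        at = torch.softmax(at, dim=-1)              # replaces the manual exp / max / sum lines
        v = torch.matmul(at, inputs)                # (batch, class_num, emb)
        return v


layer = Attention_module(class_num=3, embedding_length=8)
x = torch.randn(2, 5, 8)
out = layer(x)
print(out.shape)                                  # torch.Size([2, 3, 8])
print([n for n, _ in layer.named_parameters()])   # ['Ws']
```

Using `torch.softmax` instead of the hand-written exp/max/sum keeps the numerically stable behaviour while making the intent obvious.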

I didn’t look at the dimension ordering (the permute and others).
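One quick way to check the dimension ordering is to compare the `einsum` from the question with `torch.bmm` and the broadcasting `torch.matmul`. The tensor sizes below are arbitrary assumptions chosen so that every dimension is distinct.

```python
import torch

torch.manual_seed(0)
batch, seq_len, emb, class_num = 2, 5, 8, 3
inputs = torch.randn(batch, seq_len, emb)
Ws = torch.randn(class_num, emb)

# Ws @ inputs^T: the 2-D Ws is broadcast over the batch dimension
at = torch.matmul(Ws, inputs.permute(0, 2, 1))
print(at.shape)  # torch.Size([2, 3, 5]) -> (batch, class_num, seq_len)

at = torch.softmax(torch.tanh(at), dim=-1)

# Three ways to compute the same batched product at @ inputs
v_einsum = torch.einsum('ijk,ikl->ijl', at, inputs)
v_bmm = torch.bmm(at, inputs)
v_matmul = at @ inputs
print(v_einsum.shape)  # torch.Size([2, 3, 8]) -> (batch, class_num, emb)
print(torch.allclose(v_einsum, v_bmm), torch.allclose(v_einsum, v_matmul))
```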

Best regards



Hi, thank you for your detailed reply.

I am mostly confused about the self.add_weight function from Keras.
Is it equivalent to nn.Parameter?
An example would help a lot.

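The PyTorch way to track trainable parameters is to wrap a tensor in nn.Parameter and assign it to an attribute of the module. Take, e.g., the Linear module: its __init__ creates the weight as a Parameter around torch.empty(...) and then initializes it in reset_parameters.

torch.empty returns an uninitialized tensor, which is then initialized (albeit with funny gain parameters, so you don’t want to copy those).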
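Roughly, `add_weight` both creates the variable and registers it with the layer, which in PyTorch corresponds to assigning an `nn.Parameter` to a module attribute. Below is a minimal sketch of the same create-empty-then-initialize pattern, using a hypothetical module name `WsHolder`. It also compares the sample statistics with the Glorot bound a = sqrt(6 / (fan_in + fan_out)), which Keras's `glorot_uniform` and PyTorch's `xavier_uniform_` (with the default `gain=1.0`) both use.

```python
import math

import torch
import torch.nn as nn


class WsHolder(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, class_num, embedding_length):
        super().__init__()
        # torch.empty allocates uninitialized memory; wrapping it in nn.Parameter
        # and assigning it to self registers it as a trainable parameter.
        self.Ws = nn.Parameter(torch.empty(class_num, embedding_length))
        self.reset_parameters()

    def reset_parameters(self):
        # Xavier/Glorot uniform: U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out))
        nn.init.xavier_uniform_(self.Ws, gain=1.0)


torch.manual_seed(0)
m = WsHolder(class_num=300, embedding_length=200)
print([(n, tuple(p.shape), p.requires_grad) for n, p in m.named_parameters()])

fan_in, fan_out = 200, 300  # PyTorch uses (size(1), size(0)) for a 2-D tensor
bound = math.sqrt(6.0 / (fan_in + fan_out))
w = m.Ws.detach()
# Small tolerance because the samples are float32 while bound is a Python float
print(w.abs().max().item() <= bound + 1e-6)
# The std of U(-a, a) is a / sqrt(3); the mean should be close to 0
print(w.mean().item(), w.std().item(), bound / math.sqrt(3))
```

Keras counts fan_in and fan_out the other way around for a 2-D weight, but since only their sum enters the bound, the resulting distribution is the same.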


Why did you leave TensorFlow?