How to define a Keras-like new layer in PyTorch

Hi, everyone. I want to define a new layer that performs element-wise (point-to-point) multiplication, that is, the input x is multiplied by the parameters W. The values in W need to be learned during training.
In Keras, this can be achieved with the following code:

from keras import backend as K
from keras.engine.topology import Layer
import numpy as np


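class iLayer(Layer):
    def __init__(self, **kwargs):
        super(iLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # One weight per input element; input_shape[0] is the batch dimension
        initial_weight_value = np.random.random(input_shape[1:])
        self.W = K.variable(initial_weight_value)
        # Register W so that Keras updates it during training
        self.trainable_weights = [self.W]

    def call(self, x, mask=None):
        print(self.W.shape)  # debug output
        # Element-wise product; W broadcasts over the batch dimension
        return x * self.W

    # Keras 1.x API name; Keras 2 renamed this method to compute_output_shape
    def get_output_shape_for(self, input_shape):
        # Element-wise multiplication does not change the shape
        return input_shape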

I want to do the same in PyTorch. I have defined the following layer:

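import torch
import torch.nn as nn


class iLayer(nn.Module):
    def __init__(self):
        super(iLayer, self).__init__()
        # Learnable weight with the shape of one input sample, assumed to be (2, 100, 100)
        self.w = nn.Parameter(torch.randn((2, 100, 100)))

    def forward(self, x):
        # Element-wise multiplication; also broadcasts over a leading batch dimension
        return x * self.w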

Is this implementation correct? If not, I would really appreciate some help. Thanks a lot, everybody.

According to the docs, torch.mul performs element-wise multiplication. I wasn't sure what the * operator does on tensors, so let's test it:

import torch

a = torch.randn((2, 100, 100))
b = torch.randn((2, 100, 100))

# torch.equal is True when both tensors have the same shape and identical values
assert torch.equal(a * b, torch.mul(a, b))

Your implementation looks good to me, except for the extra underscore in this line:

super(_iLayer, self).__init__()

The class passed to super() must be the class being defined, i.e. super(iLayer, self).__init__().
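To avoid hard-coding the (2, 100, 100) shape, one option is to pass the per-sample shape to the constructor, which mirrors input_shape[1:] in the Keras build(). This is a minimal sketch, not the only way: the class name ElementwiseMul and its shape argument are hypothetical, and torch.rand (uniform on [0, 1)) is used only to match np.random.random in the Keras version.

```python
import torch
import torch.nn as nn


class ElementwiseMul(nn.Module):
    """Multiplies the input element-wise by a learnable weight tensor.

    `shape` (hypothetical argument) is the per-sample shape without the
    batch dimension, analogous to input_shape[1:] in the Keras build().
    """

    def __init__(self, shape):
        super().__init__()
        # Uniform [0, 1) initialization, like np.random.random in the Keras layer
        self.w = nn.Parameter(torch.rand(*shape))

    def forward(self, x):
        # x: (batch, *shape), self.w: (*shape) -> w broadcasts over the batch dim
        return x * self.w


layer = ElementwiseMul((2, 100, 100))
x = torch.randn(4, 2, 100, 100)  # a batch of 4 inputs
y = layer(x)
print(y.shape)  # torch.Size([4, 2, 100, 100])
```

Because * follows broadcasting rules, the same weight is applied to every sample in the batch, just as in the Keras layer.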

Thanks for your reply. One more question: will W be updated during training? In Keras, W has to be registered with self.trainable_weights = [self.W]. I don't know how to do this in PyTorch.

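Yes, w will be updated during training, as long as the module's parameters are passed to the optimizer (e.g. torch.optim.SGD(model.parameters(), ...)).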

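Assigning an nn.Parameter(...) to an attribute of the module (here self.w) automatically registers it, so it appears in model.parameters() and is treated as trainable. This is the PyTorch counterpart of self.trainable_weights = [self.W] in Keras.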
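A quick way to check this yourself is to list the module's parameters and take one optimizer step. This is a minimal sketch: the toy loss and the learning rate are arbitrary choices for illustration, and the attribute plain_w is hypothetical, added only to show that a plain tensor is not registered.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class iLayer(nn.Module):
    def __init__(self):
        super(iLayer, self).__init__()
        # nn.Parameter assigned to an attribute -> registered as trainable
        self.w = nn.Parameter(torch.randn((2, 100, 100)))
        # Hypothetical contrast: a plain tensor attribute is NOT registered
        self.plain_w = torch.randn((2, 100, 100))

    def forward(self, x):
        return x * self.w


layer = iLayer()
names = [name for name, _ in layer.named_parameters()]
print(names)  # ['w'] (plain_w is not listed)

w_before = layer.w.detach().clone()
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

x = torch.randn(4, 2, 100, 100)
loss = layer(x).pow(2).mean()  # toy loss, only to produce non-zero gradients
optimizer.zero_grad()
loss.backward()
optimizer.step()

changed = not torch.equal(w_before, layer.w.detach())
print(changed)  # True: w was updated by the optimizer step
```

In a full model you would pass model.parameters() to the optimizer; parameters of nested submodules are included automatically.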