How to construct a model based on my own formula in PyTorch

I am totally new to PyTorch and machine learning, and I am trying to build my own model from scratch. It is not a CNN or an RNN; it is just based on my own formula. The input is two matrices. In the hidden layer I want to multiply these two matrices, and then return the result from the output layer.

import torch
import torch.nn as nn

class myModel(nn.Module):
    def __init__(self, matrix_a, matrix_b):
        super(myModel, self).__init__()
        self.matrix_a = matrix_a
        self.matrix_b = matrix_b

    def build_model(self):
        self.layers = nn.ModuleList()
        i2h = self.build_input_layer()
        # h2h
        h2h = self.build_hidden_layer()
        # h2o
        h2o = self.build_output_layer()

    def build_input_layer(self):
        input_matrix_a = self.matrix_a
        input_matrix_b = self.matrix_b

    def build_hidden_layer(self):
        # I want to multiply matrix a and b
        pass  # placeholder so the class at least parses

    def forward(self,matrix_a, matrix_b):
        for layer in self.layers:
            output = layer(matrix_a, matrix_b) # not sure if it is correct
        return output
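If the goal is simply a layer that multiplies its two inputs, the multiplication can live directly in forward; a module with no parameters needs neither build_* methods nor an nn.ModuleList. A minimal sketch, assuming matrix_a and matrix_b are passed in as inputs with compatible shapes (the class name MatMulModel is hypothetical):

```python
import torch
import torch.nn as nn


class MatMulModel(nn.Module):
    """Hypothetical module: the 'hidden layer' is just a matrix product."""

    def forward(self, matrix_a, matrix_b):
        # torch.matmul requires matrix_a.shape[-1] == matrix_b.shape[-2]
        hidden = torch.matmul(matrix_a, matrix_b)
        # No further processing: the product itself is the output
        return hidden


model = MatMulModel()
a = torch.randn(2, 5)
b = torch.randn(5, 10)
out = model(a, b)
print(out.shape)  # torch.Size([2, 10])
```

Note that this module has no trainable parameters, so an optimizer would have nothing to update; to learn anything, at least one of the matrices has to be registered as an nn.Parameter.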

I'm not sure how build_model is supposed to work, as build_input_layer does not return anything (and nothing is ever added to self.layers).
However, if you just want to pass in two matrices, multiply them into a single matrix, and use that as the weight of a linear layer, this code should work:

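import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, matrix_a, matrix_b):
        super(MyModel, self).__init__()
        # Multiply the two matrices once and store the product as a single trainable weight
        self.matrix = nn.Parameter(torch.matmul(matrix_a, matrix_b))

    def forward(self, x):
        # Acts like a bias-free linear layer: here (batch, 2) @ (2, 10) -> (batch, 10)
        x = x.matmul(self.matrix)
        return x

matrix_a = torch.randn(2, 5)
matrix_b = torch.randn(5, 10)
model = MyModel(matrix_a, matrix_b)
x = torch.randn(1, 2)
output = model(x)
loss = output.mean()  # dummy scalar loss, only used to get gradients
loss.backward()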
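The model above trains the product matrix directly, so matrix_a and matrix_b themselves are never updated. If instead both factors should be learned and multiplied on every forward pass, one possible sketch (the class name FactorizedModel is hypothetical) registers each of them as its own nn.Parameter:

```python
import torch
import torch.nn as nn


class FactorizedModel(nn.Module):
    def __init__(self, matrix_a, matrix_b):
        super().__init__()
        # Each factor is its own trainable parameter (clone so training does not modify the caller's tensors)
        self.matrix_a = nn.Parameter(matrix_a.clone())
        self.matrix_b = nn.Parameter(matrix_b.clone())

    def forward(self, x):
        # The product is recomputed on every call, so gradients flow to both factors
        weight = self.matrix_a.matmul(self.matrix_b)
        return x.matmul(weight)


torch.manual_seed(0)
model = FactorizedModel(torch.randn(2, 5), torch.randn(5, 10))
x = torch.randn(1, 2)
output = model(x)      # shape (1, 10)
loss = output.mean()   # dummy scalar, only to produce gradients
loss.backward()
print(model.matrix_a.grad.shape, model.matrix_b.grad.shape)
# torch.Size([2, 5]) torch.Size([5, 10])
```

The trade-off is that the factorized version has 2*5 + 5*10 = 60 parameters instead of 2*10 = 20, and its effective weight is constrained to have rank at most 5.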

Thank you very much! Yes, you are right; my input layer is wrong. Could you please explain the line loss = output.mean()? Since the output has no label, I don't understand why there is a loss here.

I only used it to get a scalar, so that backward() can compute some valid gradients without a target. You can ignore it and instead compute the loss from your target with a proper criterion.
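In a real setup the dummy output.mean() is replaced by a criterion that compares the output with a target. A minimal sketch of a training loop, assuming a regression-style target with the same shape as the output and using nn.MSELoss with torch.optim.SGD; the batch, target values and learning rate are made up for illustration:

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self, matrix_a, matrix_b):
        super().__init__()
        # Product of the two matrices, trained as a single weight
        self.matrix = nn.Parameter(torch.matmul(matrix_a, matrix_b))

    def forward(self, x):
        return x.matmul(self.matrix)


torch.manual_seed(0)
model = MyModel(torch.randn(2, 5), torch.randn(5, 10))
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(4, 2)        # hypothetical batch of 4 inputs
target = torch.randn(4, 10)  # hypothetical targets, same shape as the output

losses = []
for step in range(100):
    optimizer.zero_grad()              # clear gradients from the previous step
    output = model(x)
    loss = criterion(output, target)   # compare prediction with the target
    loss.backward()                    # compute d(loss)/d(model.matrix)
    optimizer.step()                   # update model.matrix
    losses.append(loss.item())

print(losses[0], losses[-1])
```

With a target, loss.backward() produces gradients that actually move the weight toward reproducing the target, which the dummy mean() loss does not do.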