How do you add the input layer of a neural network as trainable parameters?

Hi,

First, a quick background.
I’m currently trying to use a simple neural network to automate the design of an electronic circuit. At the moment I have a pre-trained model, which predicts circuit performance (outputs) based on component values (inputs). Given some target outputs, I want to find the optimal input values that yield those specific targets.

My goal is to add a random set of inputs as trainable parameters to my pretrained model, such that I can use gradient descent to automatically estimate the optimal inputs.

My pre-trained model looks like this:

class NeuralNetwork(nn.Module):
    def __init__(self, in_dim, out_dim, n_hid1, n_hid2, p=0):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_stack = nn.Sequential(
            nn.Linear(in_dim, n_hid1),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(n_hid1, n_hid2),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(n_hid2, out_dim),
        )

    def forward(self, x):
        x = self.flatten(x)
        x = self.linear_stack(x)
        return x

I’ve loaded a pretrained NeuralNetwork class and frozen its parameters using:

for param in model.parameters():
    param.requires_grad = False
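
As a sanity check (and to disable the dropout layers while the inputs are being optimized), I assume something like this is also needed:

# confirm that no model parameter will be updated
assert all(not p.requires_grad for p in model.parameters())

# dropout should be inactive while searching for the optimal inputs
model.eval()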

I’ve tried to add the inputs as trainable parameters using torch.nn.Parameter() in a new module, which is where I run into problems. First, I’ve defined a new module where I’ve just added a randomly initialized tensor with shape (number of data points, input dimension).

class InputTensor(nn.Module):
    def __init__(self, input_dimension, num_data):
        super(InputTensor, self).__init__()
        self.flatten = nn.Flatten()
        self.parameter = nn.Parameter(torch.rand((num_data, input_dimension)), requires_grad=True)

    def forward(self, x):
        x = self.flatten(x)
        x = self.parameter(x)
        return x

I’ve then combined the pretrained model with the new input tensor using torch.nn.Sequential, such that the optimal inputs for a given set of outputs can be found automatically using gradient descent.

input_tensor = InputTensor(in_dim, num_datapoints)
new_model = nn.Sequential(input_tensor, model)

However, when I run my training loop, I get an error stating that the torch.nn.Parameter() object is not callable. I understand the error, but I am confused as to how to properly add the input layer as trainable parameters.

FYI, this is the error I get:

      9     def forward(self, x):
     10         x = self.flatten(x)
---> 11         x = self.parameter(x)
     12         return x
     13 

TypeError: 'Parameter' object is not callable

Regards,
Rasmus

What is the intended function of this line: x = self.parameter(x)?
nn.Parameter is a tensor-like variable that holds network parameters.
It cannot be called like a function.
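
In other words, if you keep a wrapper class, its forward should simply return the parameter instead of calling it. A rough sketch (the flatten is not needed, and the x argument is only kept so the module still fits into nn.Sequential):

class InputTensor(nn.Module):
    def __init__(self, input_dimension, num_data):
        super().__init__()
        # learnable inputs, shape (num_data, input_dimension)
        self.parameter = nn.Parameter(torch.rand(num_data, input_dimension))

    def forward(self, x=None):
        # a Parameter is data, not a callable; just return it
        return self.parameter

With new_model = nn.Sequential(InputTensor(in_dim, num_datapoints), model) you would then call the combined model with a dummy argument, e.g. new_model(None).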

I see the problem with that statement. All I want to do is generate a tensor with random values from this call and pass it to my pretrained model:

self.parameter = nn.Parameter(torch.rand((num_data, input_dimension)), requires_grad=True)

Hence, I suppose the forward pass in my InputTensor class doesn’t need to take an argument x, as the inputs are already defined as parameters.

But once I have generated the randomly initialized tensor mentioned above, how do I pass it to my pretrained model as inputs?

You do not need to create a separate class for the input.
You can simply create a tensor and pass it in as below:

data = torch.rand(num_data, input_dimension, requires_grad=True)
out = model(data)

Yes, but would I then have to pass out.parameters() to my optimizer?

In my final model I would only need to pass the targets, as the inputs should be integrated in the model as learnable parameters.

You could do something like below:

import torch
import torch.nn as nn
import torch.optim as optim


class Model(nn.Module):
    def __init__(self, inch, outch):
        super().__init__()
        self.layer = nn.Sequential(nn.Linear(inch, outch),
                                    nn.ReLU(),
                                    nn.Linear(outch, 2))

    def forward(self, x):
        return self.layer(x)


if __name__ == '__main__':
    data = nn.Parameter(torch.randn(2,4).cuda())
    opt = optim.SGD([data], lr=1e-1, momentum=0.9)

    model = Model(4, 8).cuda()
    print(f"input at start: {data}")

    totalepochs = 200
    for epoch in range(totalepochs):
        opt.zero_grad()
        out = model(data)
        out = out.softmax(dim=-1)

        # simple loss function to reduce the probability of second class
        loss = (out[:,1] * 20).sum()
        loss.backward()
        opt.step()

        if epoch % 50 == 1:
            print(f"{epoch}/{totalepochs}: out {out}")

    print(f"input at end: {data}")

