How to use an intermediate variable as the weights of later nn.Linear layers?

Hi, I’m having trouble training a network with the following structure (inspired by a recent paper):

Given inputs X and Y, we first compute a hidden variable from X:
X → fc1 → fc2 → hidden

Then we slice the hidden variable and use the slices as the weights and biases of fc3 and fc4.
Then we compute:
Y → fc3 → fc4 → prediction

The loss is computed between the prediction and a ground-truth target.
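For the layer sizes used in the code below (fc3: 2 → 50, fc4: 50 → 3), the hidden vector has to hold 50*2 + 50 + 3*50 + 3 = 303 values. A minimal sketch of the slicing step is shown here; `split_params` is a hypothetical helper (not from the thread) built on `torch.split`, and it reshapes each weight to the `[out_features, in_features]` layout that `nn.Linear` uses.

```python
import torch

def split_params(flat, in_dim=2, hidden_dim=50, out_dim=3):
    """Hypothetical helper: split a flat 1D vector into (w3, b3, w4, b4).

    Weights follow nn.Linear's convention: [out_features, in_features].
    """
    sizes = [
        hidden_dim * in_dim,   # fc3 weight: 50 * 2 = 100
        hidden_dim,            # fc3 bias:   50
        out_dim * hidden_dim,  # fc4 weight: 3 * 50 = 150
        out_dim,               # fc4 bias:   3
    ]
    assert flat.numel() == sum(sizes), f"expected {sum(sizes)} values, got {flat.numel()}"
    w3, b3, w4, b4 = torch.split(flat, sizes)
    return (
        w3.reshape(hidden_dim, in_dim),
        b3,
        w4.reshape(out_dim, hidden_dim),
        b4,
    )

# Usage: a dummy 303-element vector stands in for hidden[0].
w3, b3, w4, b4 = split_params(torch.arange(303.0))
print(w3.shape, b3.shape, w4.shape, b4.shape)
```

Computing the offsets from the layer sizes avoids hand-written slice bounds such as `weights[300:304]`, which silently run past the end of the vector.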

However, the training error never decreases. Could the problem be the in-place value assignment for fc3 and fc4? My current setup is below.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Surface_NN(torch.nn.Module):
  def __init__(self):
    super(Surface_NN, self).__init__()
    self.fc1 = nn.Linear(3, 25)
    self.fc2 = nn.Linear(25, 303)

    self.fc3 = nn.Linear(2, 50) # (2 + 1) * 50 = 150 params
    self.fc4 = nn.Linear(50, 3) # (50 + 1) * 3 = 153 params, in total 303 params

  def forward(self, Image, Surface):
    hidden = self.fc1(Image)
    hidden = F.relu(hidden)
    hidden = self.fc2(hidden)

    # NOTE: THIS SPECIFIC SETUP ONLY APPLIES TO BATCH_SIZE = 1.
    weights = hidden[0]

    fc3weight = (weights[0:100]).reshape([1, 50, 2])
    fc3bias = weights[100:150]
    fc4weight = (weights[150:300]).reshape([1, 3, 50])
    fc4bias = weights[300:303]

    # Attempt to overwrite the fc3/fc4 parameters with the slices
    self.state_dict()['fc3.weight'] = fc3weight
    self.state_dict()['fc3.bias'] = fc3bias
    self.state_dict()['fc4.weight'] = fc4weight
    self.state_dict()['fc4.bias'] = fc4bias

    middle = F.relu(self.fc3(Surface))
    result = self.fc4(middle)
    return result
```

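Assigning tensors to the state_dict keys won’t change the actual parameters unless you reload the state_dict: each call to state_dict() builds a new dictionary, so the assignments only modify a temporary dict that is thrown away.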
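A small sketch to illustrate this, with an arbitrary `nn.Linear(2, 50)` standing in for fc3 and a hypothetical `new_weight` standing in for a slice of fc2's output. It also shows why reloading would not help here either: `load_state_dict()` copies the values into the existing parameter without recording the copy in the autograd graph, so no gradient would flow back to the tensor that produced them (in the question, fc1 and fc2).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(2, 50)  # stands in for fc3
original = layer.weight.detach().clone()

# state_dict() builds a new dictionary on every call, so this assignment
# only modifies a temporary dict that is immediately discarded.
layer.state_dict()["weight"] = torch.zeros(50, 2)
unchanged = torch.equal(layer.weight, original)
print(unchanged)  # True: the parameter itself was not touched

# Reloading a modified state_dict does overwrite the values...
new_weight = torch.zeros(50, 2, requires_grad=True)  # hypothetical "generated" weight
sd = layer.state_dict()
sd["weight"] = new_weight
layer.load_state_dict(sd)
print(torch.equal(layer.weight, torch.zeros(50, 2)))  # True

# ...but the data is copied into the existing parameter, so the layer's
# output is not connected to new_weight in the autograd graph.
layer(torch.randn(1, 2)).sum().backward()
print(new_weight.grad)  # None: no gradient reaches the generator
```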
Based on your description of the use case, I think you could use the functional API and pass the sliced tensors directly as the fc3 parameters:

```python
fc3weight = weights[0:100].reshape(50, 2)  # F.linear expects a 2D [out_features, in_features] weight
fc3bias = weights[100:150]
...
out = F.linear(Surface, fc3weight, fc3bias)
...
```
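Putting the pieces together, a minimal sketch of the whole model with the functional API could look like this. It assumes batch size 1 (as in the question) and a `Surface` of shape `[N, 2]`, and it drops the fc3/fc4 `nn.Linear` modules because their parameters now come from fc2; the random inputs, the target, and the MSE loss are placeholders, not part of the original setup.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Surface_NN(nn.Module):
    """fc1/fc2 generate the parameters of two functional linear layers.

    Assumes batch size 1, as in the question.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 25)
        self.fc2 = nn.Linear(25, 303)  # 150 values for "fc3" + 153 for "fc4"

    def forward(self, Image, Surface):
        hidden = self.fc2(F.relu(self.fc1(Image)))
        weights = hidden[0]  # batch size 1 only

        # Slice, reshaping weights to 2D [out_features, in_features].
        fc3weight = weights[0:100].reshape(50, 2)
        fc3bias = weights[100:150]
        fc4weight = weights[150:300].reshape(3, 50)
        fc4bias = weights[300:303]

        middle = F.relu(F.linear(Surface, fc3weight, fc3bias))
        return F.linear(middle, fc4weight, fc4bias)

# Placeholder data: one image, 10 surface points sharing the generated layers.
model = Surface_NN()
image = torch.randn(1, 3)
surface = torch.randn(10, 2)
target = torch.randn(10, 3)

loss = F.mse_loss(model(image, surface), target)
loss.backward()
print(model.fc1.weight.grad is not None)  # True: gradients reach fc1/fc2
```

Because the generated weights are ordinary slices of `hidden`, the loss gradient flows back through them into fc1 and fc2, which the state_dict assignment could never achieve.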

Thanks a lot, this worked! (Although I needed to reshape the weights to 2D instead of 3D.)
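The question's code notes that the setup only works for batch size 1. If each sample in a batch should get its own generated layers, one option (a sketch, not something proposed in the thread) is to reshape the generator output into per-sample weight tensors and apply them with `torch.bmm`. The class name `BatchedSurfaceNN` is hypothetical, and `Surface` is assumed to have shape `[B, N, 2]` (N points per sample).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BatchedSurfaceNN(nn.Module):
    """Hypothetical batched variant: every sample gets its own fc3/fc4."""

    def __init__(self, in_dim=2, hidden_dim=50, out_dim=3):
        super().__init__()
        self.dims = (in_dim, hidden_dim, out_dim)
        n_params = hidden_dim * (in_dim + 1) + out_dim * (hidden_dim + 1)  # 303
        self.fc1 = nn.Linear(3, 25)
        self.fc2 = nn.Linear(25, n_params)

    def forward(self, Image, Surface):
        # Image: [B, 3], Surface: [B, N, in_dim]
        in_dim, hidden_dim, out_dim = self.dims
        hidden = self.fc2(F.relu(self.fc1(Image)))  # [B, 303]
        w3, b3, w4, b4 = torch.split(
            hidden,
            [hidden_dim * in_dim, hidden_dim, out_dim * hidden_dim, out_dim],
            dim=1,
        )
        B = hidden.shape[0]
        w3 = w3.reshape(B, hidden_dim, in_dim)  # per-sample [out, in]
        w4 = w4.reshape(B, out_dim, hidden_dim)

        # Per-sample x @ W^T + b: [B, N, in] @ [B, in, out] -> [B, N, out]
        middle = F.relu(torch.bmm(Surface, w3.transpose(1, 2)) + b3.unsqueeze(1))
        return torch.bmm(middle, w4.transpose(1, 2)) + b4.unsqueeze(1)

model = BatchedSurfaceNN()
out = model(torch.randn(4, 3), torch.randn(4, 10, 2))
print(out.shape)  # torch.Size([4, 10, 3])
```

Each slice of the batch computes the same thing as the batch-size-1 `F.linear` version, just for its own generated weights.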