Generate NN weights using another NN

Hello,

I have the following setup. A small network, NetA, predicts values y from a set of inputs x1:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NetA(nn.Module):
    def __init__(self, nf = 4):
        super(NetA, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(2, nf, bias=False),
            nn.BatchNorm1d(nf),
            nn.Tanh(),

            nn.Linear(nf, nf, bias=False),
            nn.BatchNorm1d(nf),
            nn.Tanh(),

            nn.Linear(nf, nf, bias=False),
            nn.BatchNorm1d(nf),
            nn.Tanh(),

            nn.Linear(nf, 1, bias=False),
        )
    def forward(self, x1):
        x = self.main(x1).view(-1)
        return x
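A quick shape check for NetA, as a sketch (nf = 4 and the batch size of 16 are arbitrary illustration values). It also shows one constraint of the BatchNorm1d layers: in train mode they need more than one sample per batch, while in eval mode they normalize with the running statistics.

```python
import torch
import torch.nn as nn


class NetA(nn.Module):
    def __init__(self, nf=4):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(2, nf, bias=False), nn.BatchNorm1d(nf), nn.Tanh(),
            nn.Linear(nf, nf, bias=False), nn.BatchNorm1d(nf), nn.Tanh(),
            nn.Linear(nf, nf, bias=False), nn.BatchNorm1d(nf), nn.Tanh(),
            nn.Linear(nf, 1, bias=False),
        )

    def forward(self, x1):
        return self.main(x1).view(-1)


net_a = NetA(nf=4)
x1 = torch.randn(16, 2)  # hypothetical batch: 16 samples with 2 input features each
y_hat = net_a(x1)
print(y_hat.shape)  # torch.Size([16])

# In train mode, BatchNorm1d needs more than one sample per batch.
try:
    net_a(torch.randn(1, 2))
except ValueError as err:
    print("batch of 1 in train mode:", err)

net_a.eval()  # eval mode normalizes with the running statistics instead
print(net_a(torch.randn(1, 2)).shape)  # torch.Size([1])
```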

There is also a second set of inputs, x2, and each x2[i] produces a different set of targets y[i] over the entire range of x1 used by NetA. I therefore pre-train NetA on x2[0] to obtain baseline weights and then rewrite the network in a functional form, fNetA:

class fNetA(nn.Module):
    def __init__(self):
        super(fNetA, self).__init__()

    def forward(self, x1, bn_param_, param_, param6_, out_param_):
        # F.batch_norm defaults to training=False, so the saved running statistics are always used
        x = torch.tanh(F.batch_norm(F.linear(x1, weight = param_[:, :2], bias = None), running_mean=bn_param_[:,0], running_var=bn_param_[:,1],
                    weight=bn_param_[:,2], bias=bn_param_[:,3]))
        x = torch.tanh(F.batch_norm(F.linear(x, weight = param_[:, 2:], bias = None), running_mean=bn_param_[:,4], running_var=bn_param_[:,5],
                    weight=bn_param_[:,6], bias=bn_param_[:,7]))
        x = torch.tanh(F.batch_norm(F.linear(x, weight = param6_, bias = None), running_mean=bn_param_[:,8], running_var=bn_param_[:,9],
                    weight=bn_param_[:,10], bias=bn_param_[:,11]))
        x = F.linear(x, weight = out_param_, bias = None)

        return x.view(-1)
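The tensors passed to fNetA must follow the layout implied by its slicing. A minimal sketch, assuming that layout: bn_param_ has shape (nf, 12) with columns [running_mean, running_var, weight, bias] for each of the three BatchNorm1d layers, param_ is the first two Linear weights concatenated along dim 1 (shape (nf, 2 + nf)), param6_ is the third Linear weight, and out_param_ is the output weight (shape (1, nf)). The helper pack_params is hypothetical, fNetA is the same computation written with a small inner helper, and a few no-grad forward passes stand in for pre-training. In eval mode the functional version should then reproduce NetA.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NetA(nn.Module):
    def __init__(self, nf=4):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(2, nf, bias=False), nn.BatchNorm1d(nf), nn.Tanh(),
            nn.Linear(nf, nf, bias=False), nn.BatchNorm1d(nf), nn.Tanh(),
            nn.Linear(nf, nf, bias=False), nn.BatchNorm1d(nf), nn.Tanh(),
            nn.Linear(nf, 1, bias=False),
        )

    def forward(self, x1):
        return self.main(x1).view(-1)


class fNetA(nn.Module):
    def forward(self, x1, bn_param_, param_, param6_, out_param_):
        # F.batch_norm defaults to training=False, so it uses the passed running statistics
        def block(x, w, k):
            x = F.linear(x, w)
            x = F.batch_norm(x, bn_param_[:, k], bn_param_[:, k + 1],
                             weight=bn_param_[:, k + 2], bias=bn_param_[:, k + 3])
            return torch.tanh(x)

        x = block(x1, param_[:, :2], 0)
        x = block(x, param_[:, 2:], 4)
        x = block(x, param6_, 8)
        return F.linear(x, out_param_).view(-1)


def pack_params(net_a):
    """Hypothetical helper: extract NetA's weights in the layout fNetA expects."""
    linears = [m for m in net_a.main if isinstance(m, nn.Linear)]
    bns = [m for m in net_a.main if isinstance(m, nn.BatchNorm1d)]
    bn_param_ = torch.stack(
        [t for bn in bns for t in (bn.running_mean, bn.running_var, bn.weight, bn.bias)],
        dim=1,
    ).detach()                                                                   # (nf, 12)
    param_ = torch.cat([linears[0].weight, linears[1].weight], dim=1).detach()  # (nf, 2 + nf)
    param6_ = linears[2].weight.detach().clone()                                 # (nf, nf)
    out_param_ = linears[3].weight.detach().clone()                              # (1, nf)
    return bn_param_, param_, param6_, out_param_


torch.manual_seed(0)
net_a = NetA(nf=4)
with torch.no_grad():
    for _ in range(5):
        net_a(torch.randn(32, 2))  # stand-in for pre-training: only updates the BN running stats
net_a.eval()

packed = pack_params(net_a)
x1 = torch.randn(10, 2)
with torch.no_grad():
    ref = net_a(x1)
    out = fNetA()(x1, *packed)
print(torch.allclose(ref, out, atol=1e-6))  # True
```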

I do this because I want to generate out_param_ with a second network, NetB:

class NetB(nn.Module):
    def __init__(self, nf, gnf):
        super(NetB, self).__init__()
        self.nf = nf
        self.embedding = nn.Embedding(5, gnf)

        self.main = nn.Sequential(
            nn.Linear(gnf, gnf, bias=False),
            nn.BatchNorm1d(gnf),
            nn.Tanh(),

            nn.Linear(gnf, gnf, bias=False),
            nn.BatchNorm1d(gnf),
            nn.Tanh(),

            nn.Linear(gnf, nf, bias=True), # output layer weights out_param_
            nn.Tanh(), # limit delta because y values are very close to baseline
        )

    def forward(self, x2):
        x = self.embedding(x2) # x2 represents specific class
        x = self.main(x).mean(0).view(-1, self.nf)
        return x
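A shape check for NetB, assuming nf = 4 and gnf = 16 (hypothetical values). Because of .mean(0), any batch of class indices collapses to a single (1, nf) row, the same shape as NetA's output weight out_param_; a batch that mixes several classes also yields just one row, averaged over those classes.

```python
import torch
import torch.nn as nn


class NetB(nn.Module):
    def __init__(self, nf, gnf):
        super().__init__()
        self.nf = nf
        self.embedding = nn.Embedding(5, gnf)
        self.main = nn.Sequential(
            nn.Linear(gnf, gnf, bias=False), nn.BatchNorm1d(gnf), nn.Tanh(),
            nn.Linear(gnf, gnf, bias=False), nn.BatchNorm1d(gnf), nn.Tanh(),
            nn.Linear(gnf, nf, bias=True), nn.Tanh(),
        )

    def forward(self, x2):
        x = self.embedding(x2)
        return self.main(x).mean(0).view(-1, self.nf)


net_b = NetB(nf=4, gnf=16)  # hypothetical sizes
single = net_b(torch.full((8,), 3, dtype=torch.long))  # batch of 8 samples, all class 3
mixed = net_b(torch.tensor([0, 1, 2, 3, 4, 0, 1, 2]))   # batch mixing several classes
print(single.shape, mixed.shape)  # torch.Size([1, 4]) torch.Size([1, 4])
```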

The training loop for NetB is:

fNetA.train() # no effect on its BN layers: F.batch_norm inside fNetA defaults to training=False
NetB.train()
for i, (x1, x2, y) in enumerate(train_dataloader):
    netb_optimizer.zero_grad()
    new_out_param_ = NetB(x2) + out_param_ # NetB output + baseline weights
    out = fNetA(x1, bn_param_, param_, param6_, new_out_param_) # bn_param_, param_, param6_ are just saved tensors
    g_loss = criterion(out, y)
    g_loss.backward()
    netb_optimizer.step()
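A runnable version of one iteration of this loop, as a sketch under assumptions: nf = 4 and gnf = 16 are hypothetical, the saved baseline tensors are random stand-ins (with identity BatchNorm statistics), the batch is random data for a single class, and nn.MSELoss and Adam are assumed for criterion and netb_optimizer. fNetA is written as the plain function f_net_a with the same computation. The backward call is made on g_loss, the variable that actually holds the loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
nf, gnf = 4, 16  # hypothetical sizes


class NetB(nn.Module):
    def __init__(self, nf, gnf):
        super().__init__()
        self.nf = nf
        self.embedding = nn.Embedding(5, gnf)
        self.main = nn.Sequential(
            nn.Linear(gnf, gnf, bias=False), nn.BatchNorm1d(gnf), nn.Tanh(),
            nn.Linear(gnf, gnf, bias=False), nn.BatchNorm1d(gnf), nn.Tanh(),
            nn.Linear(gnf, nf, bias=True), nn.Tanh(),
        )

    def forward(self, x2):
        return self.main(self.embedding(x2)).mean(0).view(-1, self.nf)


def f_net_a(x1, bn_param_, param_, param6_, out_param_):
    """Same computation as fNetA.forward; BN uses the passed running stats (training=False)."""
    def block(x, w, k):
        x = F.linear(x, w)
        x = F.batch_norm(x, bn_param_[:, k], bn_param_[:, k + 1],
                         weight=bn_param_[:, k + 2], bias=bn_param_[:, k + 3])
        return torch.tanh(x)

    x = block(x1, param_[:, :2], 0)
    x = block(x, param_[:, 2:], 4)
    x = block(x, param6_, 8)
    return F.linear(x, out_param_).view(-1)


# Random stand-ins for the saved baseline tensors (pre-trained in the original setup).
identity_bn = [torch.zeros(nf, 1), torch.ones(nf, 1), torch.ones(nf, 1), torch.zeros(nf, 1)]
bn_param_ = torch.cat(identity_bn * 3, dim=1)  # (nf, 12): mean, var, weight, bias per BN layer
param_ = torch.randn(nf, 2 + nf)
param6_ = torch.randn(nf, nf)
out_param_ = torch.randn(1, nf)

net_b = NetB(nf, gnf)
netb_optimizer = torch.optim.Adam(net_b.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# One hypothetical batch: 32 samples that all belong to class x2 = 1.
x1 = torch.randn(32, 2)
x2 = torch.full((32,), 1, dtype=torch.long)
y = torch.randn(32)

net_b.train()
netb_optimizer.zero_grad()
new_out_param_ = net_b(x2) + out_param_  # NetB output + baseline weights
out = f_net_a(x1, bn_param_, param_, param6_, new_out_param_)
g_loss = criterion(out, y)
g_loss.backward()
netb_optimizer.step()
```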

I am aware that NetB takes the mean over the batch. I built the batches so that each one contains samples for a single x2[i] only, so all samples in a batch share the same new_out_param_.
This setup works when I tune the baseline weights for one specific x2[i]: NetB can learn the delta for new_out_param_ in that single case. But when I use all x2[i] in the training dataloader, NetB does not converge at all. I tried gradient accumulation over several batches so that each weight update is based on multiple x2[i]. I also tried mixing different x2[i] within one batch, removing the mean from NetB so that it produces per-sample weights, and iterating over them while accumulating gradients before the update. Nothing really works. Removing the nn.Tanh() at the output of NetB doesn't change anything either.
I noticed that the gradients of NetB's weights are very small (around 1e-9), but they are just as small in the single-case setting, where the network still learns. Technically, I could learn every required new_out_param_ independently and then train the generator NetB on them, but I would like to do this properly, in a way that scales. So, I have two questions:
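One property of BatchNorm1d that interacts with single-class batches, shown as a quick check rather than a confirmed diagnosis: when every row of a batch holds the same embedding (all samples share one x2[i]), the batch variance is about zero, so in train mode BatchNorm1d outputs roughly its bias for every class, and the gradient reaching the layers before it almost cancels. The modules mirror the first block of NetB; the sizes (5 classes, gnf = 16) and the random loss are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical sizes: 5 classes (as in NetB's embedding) and gnf = 16.
emb = nn.Embedding(5, 16)
lin = nn.Linear(16, 16, bias=False)
bn = nn.BatchNorm1d(16)  # modules start in train mode, so batch statistics are used


def first_block(x2):
    """Embedding -> Linear -> BatchNorm1d (NetB's first block), then backprop an arbitrary loss."""
    emb.zero_grad()
    lin.zero_grad()
    h = bn(lin(emb(x2)))
    loss = (h * torch.randn_like(h)).sum()  # arbitrary loss that depends on every output
    loss.backward()
    return h.detach(), emb.weight.grad.abs().max().item(), lin.weight.grad.abs().max().item()


# Single-class batch: all 8 rows of the embedding output are identical.
h_single, emb_g_single, lin_g_single = first_block(torch.full((8,), 2, dtype=torch.long))
# Mixed batch: several classes in the same batch.
h_mixed, emb_g_mixed, lin_g_mixed = first_block(torch.arange(8) % 5)

# With zero batch variance, (h - batch_mean) is ~0, so the output is ~bn.bias for any class.
print(torch.allclose(h_single, bn.bias.detach().expand_as(h_single), atol=1e-3))
print("single-class max grads:", emb_g_single, lin_g_single)  # close to zero
print("mixed-class max grads:", emb_g_mixed, lin_g_mixed)     # much larger
```

With the original NetB, this means that under single-class batches the output of the first BatchNorm1d no longer depends on which x2[i] is in the batch.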

  1. What could be preventing NetB from learning in this case?
  2. Is there a different way to do this? For example, back-propagate the loss to the beginning of NetA and use the resulting gradient directly to train NetB (using only NetB's graph)?
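On the mechanics of the second question, a sketch of a two-stage backward: cut the graph at the generated weights, compute the gradient of the loss with respect to them with torch.autograd.grad, then push that gradient into the generator with delta.backward(grad). The small generator `gen` and the frozen two-layer `f_net` are hypothetical stand-ins for NetB and fNetA. The check shows that this yields exactly the same generator gradients as a single backward pass through the whole chain, so it changes where the graph is split, not what NetB learns.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
nf = 4

# Hypothetical stand-ins: a tiny generator in place of NetB, a frozen net in place of fNetA.
gen = nn.Sequential(nn.Embedding(5, 8), nn.Linear(8, nf))
w1 = torch.randn(nf, 2)          # frozen hidden-layer weights
out_param_ = torch.randn(1, nf)  # frozen baseline output weights


def f_net(x1, out_w):
    return F.linear(torch.tanh(F.linear(x1, w1)), out_w).view(-1)


x1, y = torch.randn(32, 2), torch.randn(32)
x2 = torch.full((32,), 3, dtype=torch.long)
criterion = nn.MSELoss()

# (a) Reference: one backward pass through the whole chain.
gen.zero_grad()
delta = gen(x2).mean(0).view(1, nf)
criterion(f_net(x1, delta + out_param_), y).backward()
ref_grads = [p.grad.clone() for p in gen.parameters()]

# (b) Two stages: gradient w.r.t. the generated weights first, then the generator's graph only.
gen.zero_grad()
delta = gen(x2).mean(0).view(1, nf)
w = (delta + out_param_).detach().requires_grad_()  # cut the graph at the generated weights
loss = criterion(f_net(x1, w), y)
(grad_w,) = torch.autograd.grad(loss, w)            # dLoss/dW, computed without gen's graph
delta.backward(grad_w)                              # chain rule: d(delta + out_param_)/d(delta) = I
two_stage_grads = [p.grad for p in gen.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(ref_grads, two_stage_grads)))  # True
```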

Thanks in advance