Defining loss functions based on a transformed version of a neural network output

Hi guys, I’m new to PyTorch and have been working on creating a pair of neural networks for function approximation. More specifically, the two networks F, G should independently approximate functions f, g:
i.e. F(x) ~ f(x) (1 input, 1 output), G(x,y) ~ g(x,y) (2 inputs, 1 output)

However, the system I’m trying to learn is a predator-prey (Lotka-Volterra) dynamics, where the output (the next position) is built from a linear combination of f and g applied to the inputs.

i.e.

x_out = x_in + dt*(a*f(x_in) - b*g(x_in, y_in))
y_out = y_in + dt*(c*g(x_in, y_in) - d*f(y_in))

where the weights a, b, c, d and the step size dt are known.

For simple prediction, I know I can run it as an RNN, but in this case I’m more interested in identifying the component functions f(x), g(x,y) of the system.

Currently, I’ve declared two neural networks F, G that are trained simultaneously on the same data. But given the constants a, b, c, d, does PyTorch allow you to multiply or otherwise modify the network outputs before calculating the loss and running the optimiser step? Would it be better to set up a single neural network with toggleable network connections to switch between F and G (if that is even possible)?
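
For clarity, here’s a minimal sketch of the kind of composition I’m asking about (plain nn.Linear stand-ins for F and G, not my actual models):

import torch
import torch.nn as nn

# Stand-ins for F (1 input) and G (2 inputs); a, b are the known weights
fnet = nn.Linear(1, 1)
gnet = nn.Linear(2, 1)
opt = torch.optim.Adam(list(fnet.parameters()) + list(gnet.parameters()))
a, b = 1.1, 0.4

xy = torch.rand(8, 2)        # dummy batch of (x, y) pairs
target = torch.rand(8, 1)    # dummy labels

pred = a*fnet(xy[:, 0:1]) - b*gnet(xy)       # combine the outputs first...
loss = nn.functional.mse_loss(pred, target)  # ...then compute the loss
loss.backward()                              # gradients reach both networks
opt.step()

My understanding is that autograd should handle this as long as everything applied to the outputs is differentiable, but I’d like to confirm this is the right way to do it.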

I’ve attached the code here with some comments to hopefully act as a guide. Any help would be greatly appreciated!

Neural network setup and data generation:

# %% Imports and HyperParameters
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
max_epochs = 3
learning_rate = 1e-5

# %% Generate data
a,b,c,d = 1.1,0.4,0.1,0.4
dt = 0.001
steps = 200000

dots = np.empty((2,steps), float)
pos = np.empty((2,steps), float)
pos[0,0],pos[1,0] = 1,1
for i in range(steps):
    # Lotka-Volterra derivatives: here f(u) = u and g(x,y) = x*y
    dots[0,i] = a*pos[0,i] - b*pos[0,i]*pos[1,i]
    dots[1,i] = c*pos[0,i]*pos[1,i] - d*pos[1,i]

    # Forward Euler step to the next position
    if i < steps-1:
        pos[:,i+1] = pos[:,i] + dt*dots[:,i]
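
(As a quick sanity check on the integration: for this system both populations should stay positive and oscillate, so I print the range of each trajectory; a blow-up here would suggest dt is too large for forward Euler.)

# Sanity check on the simulated trajectories
print("x range:", pos[0].min(), pos[0].max())
print("y range:", pos[1].min(), pos[1].max())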

# %% Define Model
class NN(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        # Plain MLP: four hidden layers of width 32 with ReLU activations
        self.linear1 = nn.Linear(input_size, 32)
        self.linear2 = nn.Linear(32, 32)
        self.linear3 = nn.Linear(32, 32)
        self.linear4 = nn.Linear(32, 32)
        self.linear5 = nn.Linear(32, output_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        x = F.relu(self.linear4(x))
        return self.linear5(x)
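
(Just to illustrate the expected shapes:)

# The MLP should map (batch, input_size) -> (batch, output_size)
net = NN(2, 1)
print(net(torch.rand(5, 2)).shape)  # expected: torch.Size([5, 1])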

# %% Input Dataset Class
class NumericalDataset(Dataset):
    def __init__(self, X,Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        return self.X.shape[1]

    def __getitem__(self, idx):
        # Return (2,) tensors so the DataLoader stacks them into
        # (batch_size, 2) batches with columns x and y
        X = torch.tensor(self.X[:, idx], dtype=torch.float32)
        Y = torch.tensor(self.Y[:, idx], dtype=torch.float32)
        return X, Y

# %% Input Datasets and load
val_ratio = 0.2
test_ratio = 0.2

val_length = int(steps*val_ratio)
test_length = int(steps*test_ratio)
train_length = steps-(val_length+test_length)

train_data = pos[:,:train_length]
val_data = pos[:,train_length:train_length+val_length]
test_data = pos[:,train_length+val_length:]

train_in = train_data[:,:-1]
val_in = val_data[:,:-1]
test_in = test_data[:,:-1]

# and labels
train_out = train_data[:,1:]
val_out = val_data[:,1:]
test_out = test_data[:,1:]

train = NumericalDataset(train_in, train_out)
train_dataloader = DataLoader(train, batch_size = batch_size, shuffle = True)
val = NumericalDataset(val_in, val_out)
val_dataloader = DataLoader(val, batch_size = batch_size, shuffle = False)
test = NumericalDataset(test_in, test_out)
test_dataloader = DataLoader(test, batch_size = batch_size, shuffle = False)
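
(And a quick check that one batch has the shapes the training loop expects:)

# Each batch should be (batch_size, 2): one row per sample, columns x and y
X, Y = next(iter(train_dataloader))
print(X.shape, Y.shape)  # expected: torch.Size([32, 2]) torch.Size([32, 2])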

Training and optimising:

# %% Loss and optimisation
modelF = NN(1,1).to(device)
optimF = torch.optim.Adam(modelF.parameters(), lr=learning_rate, weight_decay = 0.01)

modelG = NN(2,1).to(device)
optimG = torch.optim.Adam(modelG.parameters(), lr=learning_rate, weight_decay = 0.01)

loss_fn = torch.nn.MSELoss(reduction = "mean")

# %% Model Training
epoch_train_F = []
epoch_train_G = []
epoch_val = []
for epoch in range(max_epochs):
    # Training phase: alternate one F update and one G update per batch
    F_losses = []
    G_losses = []
    for i, (X,Y) in enumerate(train_dataloader):

        X = X.to(device = device)   # shape (batch, 2); columns are x and y
        Y = Y.to(device = device)

        # Train F on X (G is frozen simply by not stepping optimG;
        # note eval() does not block gradients, it only switches layers
        # like dropout/batchnorm, of which there are none here)
        modelF.train()
        modelG.eval()
        optimF.zero_grad()
        optimG.zero_grad()

        # Forward and error loss calculation
        F_X1 = modelF(X[:, 0:1])    # F applied to x
        F_X2 = modelF(X[:, 1:2])    # F applied to y
        G_X1X2 = modelG(X)          # G applied to (x, y)

        Y1 = a*F_X1 - b*G_X1X2
        Y2 = c*G_X1X2 - d*F_X2
        # Labels are the *next* positions, so the prediction needs the
        # Euler step; cat along dim=1 keeps the (batch, 2) layout of Y
        y_pred = X + dt*torch.cat((Y1, Y2), dim=1)
        loss = loss_fn(input = y_pred, target = Y)
        F_losses.append(loss.item())

        # Backward
        loss.backward()

        # Gradient descent and update F's weights only
        optimF.step()

        # Train G on X
        modelF.eval()
        modelG.train()
        optimF.zero_grad()  # discard the stale gradients from the F step
        optimG.zero_grad()

        # Forward and error loss calculation
        F_X1 = modelF(X[:, 0:1])
        F_X2 = modelF(X[:, 1:2])
        G_X1X2 = modelG(X)

        Y1 = a*F_X1 - b*G_X1X2
        Y2 = c*G_X1X2 - d*F_X2
        y_pred = X + dt*torch.cat((Y1, Y2), dim=1)
        loss = loss_fn(input = y_pred, target = Y)
        G_losses.append(loss.item())

        # Backward
        loss.backward()

        # Gradient descent and update G's weights only
        optimG.step()

        if i%100==0:
            print(f'Train Epochs: [{epoch}/{max_epochs}], Iter: {i}, F_losses: {np.mean(F_losses)}, G_losses: {np.mean(G_losses)}')

    epoch_train_F.append(np.mean(F_losses))
    epoch_train_G.append(np.mean(G_losses))

    # Validation phase
    modelF.eval()
    modelG.eval()

    val_losses = []
    with torch.no_grad():   # no graph needed during validation, saves memory
        for i, (X,Y) in enumerate(val_dataloader):
            X = X.to(device = device)
            Y = Y.to(device = device)

            # Forward and error loss calculation
            F_X1 = modelF(X[:, 0:1])
            F_X2 = modelF(X[:, 1:2])
            G_X1X2 = modelG(X)

            Y1 = a*F_X1 - b*G_X1X2
            Y2 = c*G_X1X2 - d*F_X2
            y_pred = X + dt*torch.cat((Y1, Y2), dim=1)

            loss = loss_fn(input = y_pred, target = Y)
            val_losses.append(loss.item())

            if i%20==0:
                print(f'Val Epochs: [{epoch}/{max_epochs}], Iter: {i}, Loss: {np.mean(val_losses)}')

    epoch_val.append(np.mean(val_losses))

    # Early stopping: break if validation loss has risen for 4 epochs in a row
    if epoch>4:
        A = np.array(epoch_val[-5:-1])
        B = np.array(epoch_val[-4:])
        diff = A-B
        if all(value < 0 for value in diff):
            break
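
Once training finishes, my plan for checking whether the components were actually identified is to compare F against the true f(x) = x (and similarly G against g(x,y) = x*y) over the range the trajectories cover, something like this sketch:

# %% Inspect the learned component
# Compare the learned F against the true component f(x) = x;
# the same idea works for G against g(x,y) = x*y
modelF.eval()
with torch.no_grad():
    xs = torch.linspace(0.0, 5.0, 100, device=device).unsqueeze(1)
    fx = modelF(xs)
print("max |F(x) - x|:", (fx - xs).abs().max().item())

I realise F might only recover f up to whatever the loss can pin down, which is partly why I’m asking whether this training setup makes sense.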