Backpropagating correctly with 3 consecutive nn.Module(s)

Hello there!

First, sorry if my post does not fit the autograd category exactly, but it seemed at least somewhat related to it.

Alright, I'll try to explain my issue as clearly as possible!

Long story short: not all the weights in my modules are updated.

The detailed story:

My first network returns two values, u_lin and v; it is a very classic static feedforward network.

    u_lin, v = u_lin_mod(batch_u, batch_x)

u_lin and v are produced by alpha and beta layers.
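For readers who want something concrete, here is a minimal sketch of what such a two-output network could look like. The class name ULinMod, the idea that alpha and beta are small MLPs over the concatenated inputs, and all the sizes are hypothetical assumptions, not the actual u_lin_mod:

```python
import torch
import torch.nn as nn


class ULinMod(nn.Module):
    """Hypothetical two-output feedforward network: alpha produces u_lin, beta produces v."""

    def __init__(self, n_u, n_x, n_v, hidden=32):
        super().__init__()
        # Assumption: alpha and beta are small MLPs over the concatenated inputs
        self.alpha = nn.Sequential(nn.Linear(n_u + n_x, hidden), nn.Tanh(), nn.Linear(hidden, n_u))
        self.beta = nn.Sequential(nn.Linear(n_u + n_x, hidden), nn.Tanh(), nn.Linear(hidden, n_v))

    def forward(self, batch_u, batch_x):
        inp = torch.cat([batch_u, batch_x], dim=-1)  # nn.Linear acts on the last dimension
        u_lin = self.alpha(inp)
        v = self.beta(inp)
        return u_lin, v


u_lin_mod = ULinMod(n_u=1, n_x=2, n_v=1)
batch_u = torch.randn(50, 8, 1)  # (m, q, n_u): m time steps, q subsequences
batch_x = torch.randn(50, 8, 2)  # (m, q, n_x)
u_lin, v = u_lin_mod(batch_u, batch_x)
print(u_lin.shape, v.shape)  # torch.Size([50, 8, 1]) torch.Size([50, 8, 1])
```

With these shapes, v has the (m, q, n_u) layout that the RK4 forward method below expects for u_batch.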

This is where the trouble begins.

I am working with Neural ODEs, so one module implements an integration scheme (RK4).

v is then fed into this RK4 module, which is basically a fancy for loop:

    z_sim_torch_fit, y_sim_torch_fit = M2_solution(v, batch_z0_hidden)

z0 (batch_z0_hidden) is the initial condition for the integration.

The M2_solution constructor:

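    import torch.nn as nn

    class RK4Simulator(nn.Module):
        def __init__(self, ss_model, ts=1.0, scheme='RK44', device="cpu"):
            super(RK4Simulator, self).__init__()
            self.ss_model = ss_model  # nn.Module attribute, so it is registered as a submodule
            self.ts = ts              # integration step size
            self.device = device
            # note: `scheme` is accepted but not used in this snippet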

The forward method:

    import torch
    from typing import List

    def forward(self, u_batch, x0_batch):
        """Multi-step RK4 simulation over (mini)batches.

        Parameters
        ----------
        u_batch: Tensor. Size: (m, q, n_u)
            Input sequence for each subsequence in the minibatch
        x0_batch: Tensor. Size: (q, n_x)
            Initial state for each subsequence in the minibatch

        Returns
        -------
        X_sim: Tensor. Size: (m, q, n_x)
            Simulated state for all subsequences in the minibatch
        Y_sim: Tensor. Size: (m, q, n_y), n_y being the output size of ss_model
            Simulated output for all subsequences in the minibatch
        """
        X_sim_list: List[torch.Tensor] = []
        Y_sim_list: List[torch.Tensor] = []
        x_step = x0_batch
        dt2 = self.ts / 2.0  # half step, constant over the loop

        for u_step in u_batch.split(1):  # one time step at a time along dim 0
            u_step = u_step.squeeze(0)   # (q, n_u)
            X_sim_list += [x_step]

            # Classic RK4 stages; the input is held constant over the step
            k1, y_step = self.ss_model(u_step, x_step)
            Y_sim_list += [y_step]
            k2_dx, _ = self.ss_model(u_step, x_step + dt2 * k1)
            k3_dx, _ = self.ss_model(u_step, x_step + dt2 * k2_dx)
            k4_dx, _ = self.ss_model(u_step, x_step + self.ts * k3_dx)

            dx = self.ts / 6.0 * (k1 + 2.0 * k2_dx + 2.0 * k3_dx + k4_dx)
            x_step = x_step + dx

        X_sim = torch.stack(X_sim_list, 0)
        Y_sim = torch.stack(Y_sim_list, 0)
        return X_sim, Y_sim
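Written out, the loop body is the classic fourth-order Runge-Kutta step with the input held constant over each step. In the notation below (introduced here, not in the original code), f and h are the first and second outputs of ss_model, Δt is self.ts, u_t is the t-th slice of u_batch and x_t the current state:

```latex
\begin{aligned}
k_1 &= f(u_t,\, x_t), \qquad y_t = h(u_t,\, x_t) \\
k_2 &= f\!\left(u_t,\; x_t + \tfrac{\Delta t}{2}\, k_1\right) \\
k_3 &= f\!\left(u_t,\; x_t + \tfrac{\Delta t}{2}\, k_2\right) \\
k_4 &= f\!\left(u_t,\; x_t + \Delta t\, k_3\right) \\
x_{t+1} &= x_t + \tfrac{\Delta t}{6}\left(k_1 + 2k_2 + 2k_3 + k_4\right)
\end{aligned}
```

Every term is built from differentiable tensor operations, so, as long as nothing is detached, autograd can propagate the loss gradient back through each x_{t+1} to u_t, i.e. to v.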

So when I call loss.backward(), only the weights of M2_solution are updated, not those of u_lin_mod.
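One way to see where the gradient stops is to chain stand-in modules the same way (v, then the RK4 loop, then the loss), call backward(), and print each parameter's .grad. This is a sketch under assumptions: SSModel, rk4_simulate, the single nn.Linear standing in for u_lin_mod, and all shapes are hypothetical, not the code from this post:

```python
import torch
import torch.nn as nn


class SSModel(nn.Module):
    """Hypothetical state-space model returning (dx, y), like ss_model above."""

    def __init__(self, n_x=2, n_u=1, n_y=1, hidden=16):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(n_x + n_u, hidden), nn.Tanh(), nn.Linear(hidden, n_x))
        self.g = nn.Linear(n_x, n_y)

    def forward(self, u, x):
        dx = self.f(torch.cat([x, u], dim=-1))
        return dx, self.g(x)


def rk4_simulate(ss_model, u_batch, x0, ts=1.0):
    # Same classic RK4 loop as RK4Simulator.forward, written as a function for brevity
    x, ys = x0, []
    for u in u_batch.split(1):
        u = u.squeeze(0)
        k1, y = ss_model(u, x)
        ys.append(y)
        k2, _ = ss_model(u, x + ts / 2 * k1)
        k3, _ = ss_model(u, x + ts / 2 * k2)
        k4, _ = ss_model(u, x + ts * k3)
        x = x + ts / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return torch.stack(ys, 0)


torch.manual_seed(0)
u_lin_mod = nn.Linear(3, 1)  # hypothetical stand-in for the first network, producing v
ss_model = SSModel()

batch_u = torch.randn(20, 4, 1)
batch_x = torch.randn(20, 4, 2)
z0 = torch.zeros(4, 2)

v = u_lin_mod(torch.cat([batch_u, batch_x], dim=-1))  # (20, 4, 1)
y_sim = rk4_simulate(ss_model, v, z0, ts=0.1)
loss = y_sim.pow(2).mean()

# The optimizer only steps the parameters it was given, so both modules are passed here
optimizer = torch.optim.Adam(list(u_lin_mod.parameters()) + list(ss_model.parameters()), lr=1e-3)
optimizer.zero_grad()
loss.backward()

for name, module in [("u_lin_mod", u_lin_mod), ("ss_model", ss_model)]:
    for pname, p in module.named_parameters():
        status = "None" if p.grad is None else f"{p.grad.abs().sum().item():.3e}"
        print(f"{name}.{pname}: grad = {status}")

optimizer.step()
```

If a check like this shows gradients on u_lin_mod but its weights still do not change in the real training loop, one thing worth verifying is that its parameters were actually handed to the optimizer.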

However, I want the RK4 class to be as generic as possible, and I do not want to call it with u_batch and x_batch.

So I tried changing the constructor to this:

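    import torch.nn as nn

    class RK4Simulator(nn.Module):
        def __init__(self, ss_model, u_lin, ts=1.0, scheme='RK44', device="cpu"):
            super(RK4Simulator, self).__init__()
            self.ss_model = ss_model
            self.ts = ts
            self.device = device
            self.u_lin_model = u_lin  # the first network, now stored as a submodule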

and, in the main loop, to call:

    u_lin, v = M2_solution.u_lin_model(batch_u, batch_x)
    z_sim_torch_fit, y_sim_torch_fit = M2_solution(v, batch_z0_hidden)

but it does not work either. Is there a way to work around this without writing a lot of different classes? (I have two other groups of modules performing the same operations.)
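Assigning an nn.Module to an attribute inside __init__ registers it as a submodule, so with the second constructor, M2_solution.parameters() should already include the u_lin_model weights. A quick sketch to confirm this; the trimmed-down class and the nn.Linear stand-ins are hypothetical and only serve to count parameters:

```python
import torch
import torch.nn as nn


class RK4Simulator(nn.Module):
    # Trimmed-down version of the second constructor (forward omitted)
    def __init__(self, ss_model, u_lin, ts=1.0):
        super().__init__()
        self.ss_model = ss_model  # nn.Module attribute -> registered submodule
        self.u_lin_model = u_lin  # likewise registered, so its parameters are included
        self.ts = ts              # plain float, not a parameter


ss_model = nn.Linear(3, 2)  # hypothetical stand-ins
u_lin = nn.Linear(3, 1)
M2_solution = RK4Simulator(ss_model, u_lin)

n_params = sum(p.numel() for p in M2_solution.parameters())
print(n_params)  # 12  (3*2 + 2 from ss_model, 3*1 + 1 from u_lin_model)
print([name for name, _ in M2_solution.named_parameters()])
# ['ss_model.weight', 'ss_model.bias', 'u_lin_model.weight', 'u_lin_model.bias']

optimizer = torch.optim.Adam(M2_solution.parameters(), lr=1e-3)
```

If an optimizer built from M2_solution.parameters() still leaves u_lin_model unchanged, inspecting each parameter's .grad right after backward() is a natural next step.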

Thanks a lot, guys!