Hello there!
First, I am sorry if my post does not exactly fit the autograd category, but it seemed at least somewhat related to it.
Alright, I will try to explain my issue as clearly as possible!
Long story short: not all the weights in my modules are updated.
The detailed story:
I have a first network, a very classic static feedforward network, returning 2 values: u_lin and v.
```python
u_lin, v = u_lin_mod(batch_u, batch_x)
```
u_lin and v are produced by the alpha and beta layers; a rough sketch of u_lin_mod is below.
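For context, u_lin_mod is shaped roughly like this (a minimal sketch: the class name ULinModule, the layer sizes, and the activations are made up for illustration; only the alpha/beta layers and the two outputs match my actual model):

```python
import torch
import torch.nn as nn

class ULinModule(nn.Module):  # hypothetical name standing in for u_lin_mod's class
    def __init__(self, n_u, n_x, n_hidden=64):
        super().__init__()
        # alpha produces u_lin, beta produces v
        self.alpha = nn.Sequential(
            nn.Linear(n_u + n_x, n_hidden), nn.Tanh(), nn.Linear(n_hidden, 1)
        )
        self.beta = nn.Sequential(
            nn.Linear(n_u + n_x, n_hidden), nn.Tanh(), nn.Linear(n_hidden, 1)
        )

    def forward(self, u, x):
        ux = torch.cat((u, x), dim=-1)
        return self.alpha(ux), self.beta(ux)
```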
This is where the trouble begins.
I am dealing with "Neural ODEs", so there is a module which implements an integration scheme (RK4).
v is then fed into this RK4 module, which is basically a fancy for loop:
```python
z_sim_torch_fit, y_sim_torch_fit = M2_solution(v, batch_z0_hidden)
```
z0 (batch_z0_hidden in the call above) is the initial condition for the integration.
The M2_solution constructor:
```python
def __init__(self, ss_model, ts=1.0, scheme='RK44', device="cpu"):
    super(RK4Simulator, self).__init__()
    self.ss_model = ss_model
    self.ts = ts
    self.device = device
```
The forward method:
```python
import torch
from typing import List

def forward(self, u_batch, x0_batch):
    """Multi-step simulation over (mini)batches

    Parameters
    ----------
    u_batch: Tensor. Size: (m, q, n_u)
        Input sequence for each subsequence in the minibatch
    x0_batch: Tensor. Size: (q, n_x)
        Initial state for each sequence in the minibatch

    Returns
    -------
    Tensor. Size: (m, q, n_x)
        Simulated state for all subsequences in the minibatch
    Tensor. Size: (m, q, n_y)
        Simulated output for all subsequences in the minibatch
    """
    X_sim_list: List[torch.Tensor] = []
    Y_sim_list: List[torch.Tensor] = []
    x_step = x0_batch
    for u_step in u_batch.split(1):  # equivalent to: for i in range(seq_len)
        u_step = u_step.squeeze(0)
        X_sim_list += [x_step]
        dt2 = self.ts / 2.0
        # classic RK4 stages; ss_model returns (state derivative, output)
        k1, y_step = self.ss_model(u_step, x_step)
        Y_sim_list += [y_step]
        k2_dx, _ = self.ss_model(u_step, x_step + dt2 * k1)
        k3_dx, _ = self.ss_model(u_step, x_step + dt2 * k2_dx)
        k4_dx, _ = self.ss_model(u_step, x_step + self.ts * k3_dx)
        dx = self.ts / 6.0 * (k1 + 2.0 * k2_dx + 2.0 * k3_dx + k4_dx)
        x_step = x_step + dx
    X_sim = torch.stack(X_sim_list, 0)
    Y_sim = torch.stack(Y_sim_list, 0)
    return X_sim, Y_sim
```
So when I do `loss.backward()`, only the weights of M2_solution are updated, and not the ones of u_lin_mod.
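For reference, this is how I check it (a minimal sketch of the diagnostic, assuming both modules are plain nn.Modules): inspecting the gradients right after the backward pass.

```python
loss.backward()

# Gradients of the first network: if backprop reached u_lin_mod through v,
# these should be non-None and non-zero
for name, p in u_lin_mod.named_parameters():
    print("u_lin_mod", name, None if p.grad is None else p.grad.abs().sum().item())

# Gradients of the RK4 module (these do get populated in my case)
for name, p in M2_solution.named_parameters():
    print("M2_solution", name, None if p.grad is None else p.grad.abs().sum().item())
```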
However, I want the RK4 class to be as generic as possible, and I do not want to call it with u_batch and x_batch.
So I tried to change the constructor to this:
```python
def __init__(self, ss_model, u_lin, ts=1.0, scheme='RK44', device="cpu"):
    super(RK4Simulator, self).__init__()
    self.ss_model = ss_model
    self.ts = ts
    self.device = device
    self.u_lin_model = u_lin
```
and, in the main loop, to call:
```python
u_lin, v = M2_solution.u_lin_model(batch_u, batch_x)
z_sim_torch_fit, y_sim_torch_fit = M2_solution(v, batch_z0_hidden)
```
but it does not work either. Is there a way to work around this without writing a lot of different classes? (Because I have 2 other groups of modules performing the same operations.)
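In case it helps, the main loop looks roughly like this (a sketch: the optimizer choice, the learning rate, and batch_y as the fitting target are placeholders for my actual setup):

```python
optimizer = torch.optim.Adam(
    list(u_lin_mod.parameters()) + list(M2_solution.parameters()), lr=1e-3
)

for epoch in range(n_epochs):
    optimizer.zero_grad()
    u_lin, v = u_lin_mod(batch_u, batch_x)
    z_sim_torch_fit, y_sim_torch_fit = M2_solution(v, batch_z0_hidden)
    loss = torch.nn.functional.mse_loss(y_sim_torch_fit, batch_y)
    loss.backward()
    optimizer.step()
```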
Thanks a lot guys!