Hello there!
First, I am sorry if my post does not fit the autograd category exactly, but it seemed at least somewhat related to it.
Alright, I'll try to explain my issue as clearly as possible!
Long story short: not all the weights in my modules are updated.
The detailed story:
I have a first network returning 2 values, u_lin and v – a very classic static feedforward network:

```python
u_lin, v = u_lin_mod(batch_u, batch_x)
```

u_lin and v are produced by the alpha and beta layers.
This is where the trouble begins.
I am dealing with “Neural ODEs”, so there is a module implementing an integration scheme (RK4).
v is then fed into this RK4 module, which is basically a fancy for loop:

```python
z_sim_torch_fit, y_sim_torch_fit = M2_solution(v, batch_z0_hidden)
```

z0 is the initial condition for the integration.
The M2_solution constructor:

```python
def __init__(self, ss_model, ts=1.0, scheme='RK44', device="cpu"):
    super(RK4Simulator, self).__init__()
    self.ss_model = ss_model
    self.ts = ts
    self.device = device
```
The forward method:

```python
def forward(self, u_batch, x0_batch):
    """Multi-step simulation over (mini)batches.

    Parameters
    ----------
    u_batch: Tensor. Size: (m, q, n_u)
        Input sequence for each subsequence in the minibatch
    x0_batch: Tensor. Size: (q, n_x)
        Initial state for each sequence in the minibatch

    Returns
    -------
    Tensor. Size: (m, q, n_x)
        Simulated state for all subsequences in the minibatch
    """
    X_sim_list = []
    Y_sim_list: List[torch.Tensor] = []
    x_step = x0_batch
    for u_step in u_batch.split(1):  # iterate over the time steps
        u_step = u_step.squeeze(0)
        X_sim_list += [x_step]
        dt2 = self.ts / 2.0
        # Classic RK4 stages, each evaluated through ss_model
        k1, y_step = self.ss_model(u_step, x_step)
        Y_sim_list += [y_step]
        k2_dx, _ = self.ss_model(u_step, x_step + dt2 * k1)
        k3_dx, _ = self.ss_model(u_step, x_step + dt2 * k2_dx)
        k4_dx, _ = self.ss_model(u_step, x_step + self.ts * k3_dx)
        dx = self.ts / 6.0 * (k1 + 2.0 * k2_dx + 2.0 * k3_dx + k4_dx)
        x_step = x_step + dx
    X_sim = torch.stack(X_sim_list, 0)
    Y_sim = torch.stack(Y_sim_list, 0)
    return X_sim, Y_sim
```
So when I call loss.backward(), only the weights of M2_solution are updated, not the ones of u_lin_mod.
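To illustrate what I mean, here is a minimal, self-contained sketch (with plain nn.Linear stand-ins, not my actual u_lin_mod and M2_solution) of what I would expect: when the output of one module is fed into a second one, the first module's parameters should also receive gradients after backward():

```python
import torch
import torch.nn as nn

# Stand-ins for the two chained modules (hypothetical, just nn.Linear here):
first = nn.Linear(3, 2)   # plays the role of u_lin_mod
second = nn.Linear(2, 1)  # plays the role of the model inside M2_solution

x = torch.randn(5, 3)
v = first(x)              # v carries a grad_fn, linking both modules in one graph
loss = second(v).sum()
loss.backward()

# If the graph is intact, every parameter of `first` has a populated .grad:
for name, p in first.named_parameters():
    print(name, p.grad is not None)  # prints True for both weight and bias
```

In my real setup, checking `.grad` on the parameters of u_lin_mod like this is how I see that they are not being updated.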
However, I want the RK4 class to be as generic as possible, and I do not want to call it with u_batch and x_batch.
So I tried changing the constructor to this:
```python
def __init__(self, ss_model, u_lin, ts=1.0, scheme='RK44', device="cpu"):
    super(RK4Simulator, self).__init__()
    self.ss_model = ss_model
    self.ts = ts
    self.device = device
    self.u_lin_model = u_lin
```
and to call this in the main loop:

```python
u_lin, v = M2_solution.u_lin_model(batch_u, batch_x)
z_sim_torch_fit, y_sim_torch_fit = M2_solution(v, batch_z0_hidden)
```
but it does not work either. Is there a way to work around this without writing a lot of different classes? (Because I have 2 other groups of modules performing the same operations.)
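For what it's worth, this is the kind of generic pattern I am after – a thin wrapper sketched with made-up names (my understanding is that assigning submodules as attributes of an nn.Module registers their parameters), so that a single model.parameters() call covers both networks:

```python
import torch.nn as nn

class CombinedModel(nn.Module):
    """Sketch: assigning submodules as attributes registers their
    parameters, so model.parameters() covers both of them."""
    def __init__(self, u_lin_mod, solver):
        super().__init__()
        self.u_lin_mod = u_lin_mod  # registered submodule
        self.solver = solver        # registered submodule (e.g. the RK4 module)

    def forward(self, batch_u, batch_x, batch_z0_hidden):
        u_lin, v = self.u_lin_mod(batch_u, batch_x)
        return self.solver(v, batch_z0_hidden)

# A single optimizer would then see the parameters of both submodules:
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```

but I am not sure this is the idiomatic way to keep the RK4 class generic.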
Thanks a lot, guys!