Training Neural Event ODE

Hi,

I have been using the torchdiffeq library from the Neural ODE and Neural Event ODE papers and have been having trouble training the event ODE. I'm still new to PyTorch, so I'm starting with a simple model: a single spiking neuron, where the network should learn the voltage dynamics and an event function whose event corresponds to a spike. However, I can't figure out how to train the event function, and there's no event-training code publicly available. I use two losses: one between the predicted spike (event) times and the actual spike times, and one between the predicted and actual voltage trajectories. I call backward() and step(), but list(event.parameters())[0].grad is None (where event is the network for the event function), and list(event.parameters())[0] does not change between iterations.
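To make the setup concrete: the solver treats an event as the time at which the event function's output crosses zero, so a spike can be encoded by something like `v - v_threshold`. The sketch below is plain Python (no torchdiffeq, no learned networks) and assumes a hypothetical leaky integrate-and-fire neuron, dv/dt = (I - v)/tau, with a fixed threshold. It integrates with forward Euler and locates the first zero crossing of the event function by linear interpolation. The names `lif_rhs`, `threshold_event`, `first_event_time` and all constants are illustrative assumptions, not torchdiffeq's API.

```python
import math

# Hypothetical leaky integrate-and-fire dynamics: dv/dt = (I_IN - v) / TAU.
TAU = 10.0   # membrane time constant (illustrative value)
I_IN = 2.0   # constant input drive (illustrative value)
V_TH = 1.0   # spike threshold (illustrative value)


def lif_rhs(t, v):
    """Right-hand side of the ODE (plays the role of `func`)."""
    return (I_IN - v) / TAU


def threshold_event(t, v):
    """Event function: negative below threshold, zero exactly at a spike."""
    return v - V_TH


def first_event_time(rhs, event_fn, v0, t0=0.0, dt=1e-3, t_max=100.0):
    """Integrate with forward Euler until event_fn changes sign upward.

    Returns (t_event, v_event), refining the crossing by linear
    interpolation between the last two steps, or None if no event occurs.
    """
    t, v = t0, v0
    g = event_fn(t, v)
    while t < t_max:
        v_next = v + dt * rhs(t, v)
        t_next = t + dt
        g_next = event_fn(t_next, v_next)
        if g < 0.0 <= g_next:  # upward zero crossing = spike
            frac = -g / (g_next - g)
            return t + frac * dt, v + frac * (v_next - v)
        t, v, g = t_next, v_next, g_next
    return None


t_spike, v_spike = first_event_time(lif_rhs, threshold_event, v0=0.0)
# Closed form for this linear ODE starting at v = 0: t* = -tau * ln(1 - V_TH / I_IN)
t_exact = -TAU * math.log(1.0 - V_TH / I_IN)
print(t_spike, t_exact)
```

Both printed values are close to 6.93; the small gap comes from the Euler step size, not from the event detection.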

I've read that this is usually caused by breaking the computation graph, but I don't think I'm doing that anywhere in my code. The GitHub README says that both the returned event time and state can be differentiated, and that gradients will be backpropagated through the event function. My event network is clearly not learning, so I'm not sure where my code is wrong. Any help would be greatly appreciated.
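Why the event time is differentiable at all follows from the implicit function theorem: if t* satisfies g_phi(t*, z(t*)) = 0, then dt*/dphi = -(dg/dphi) / (dg/dt + dg/dz · f(t*, z(t*))), where f is the ODE's right-hand side. The numerator dg/dphi is the only term through which the event function's own parameters phi enter. The sketch below checks this formula against a central finite difference for the same hypothetical LIF neuron, treating the threshold theta in g = v - theta as the single "event parameter". It is plain Python and makes no claim about how torchdiffeq implements this internally.

```python
import math

TAU, I_IN = 10.0, 2.0  # hypothetical LIF constants: dv/dt = (I_IN - v) / TAU


def spike_time(theta):
    """Closed-form first crossing of v(t) = I_IN * (1 - exp(-t / TAU)) with theta."""
    return -TAU * math.log(1.0 - theta / I_IN)


def implicit_grad(theta):
    """dt*/dtheta from the implicit function theorem for g(t, v; theta) = v - theta.

    Here dg/dtheta = -1, dg/dt = 0 and dg/dv = 1, so
    dt*/dtheta = -(-1) / (0 + 1 * f(t*, v*)), with v* = theta at the event.
    """
    f_at_event = (I_IN - theta) / TAU  # dv/dt evaluated at the spike
    return 1.0 / f_at_event


theta, eps = 1.0, 1e-6
fd = (spike_time(theta + eps) - spike_time(theta - eps)) / (2 * eps)
print(implicit_grad(theta), fd)  # both approximately 10.0
```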

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint, odeint_event

# device, v0, t0, t, v and st come from my data setup (not shown):
# v0/t0 = initial voltage and time, t = time grid, v = ground-truth voltage on t,
# st = ground-truth spike times.


class ODEFunc(nn.Module):
    """Learned vector field dv/dt = f(t, v)."""

    def __init__(self, latent_dim=1, nhidden=1024):
        super(ODEFunc, self).__init__()
        self.elu = nn.ELU(inplace=True)
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, nhidden)
        self.fc3 = nn.Linear(nhidden, latent_dim)
        self.nfe = 0  # number of function evaluations
        self.r = 5

    def forward(self, t, x):
        self.nfe += 1
        out = self.fc1(x)
        out = self.elu(out)
        out = self.fc2(out)
        out = self.elu(out)
        out = self.fc3(out)
        return out


class ODEEvent(nn.Module):
    """Learned event function; odeint_event stops where its output crosses zero."""

    def __init__(self, latent_dim=1, nhidden=256):
        super(ODEEvent, self).__init__()
        self.elu = nn.ELU(inplace=True)
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, nhidden)
        self.fc3 = nn.Linear(nhidden, latent_dim)
        self.nfe = 0  # number of function evaluations
        self.r = 5

    def forward(self, t, x):
        self.nfe += 1
        out = self.fc1(x)
        out = self.elu(out)
        out = self.fc2(out)
        out = self.elu(out)
        out = self.fc3(out)
        return out


loss_fn = nn.MSELoss(reduction='sum')
func = ODEFunc().to(device).double()
event = ODEEvent().to(device).double()
params = list(func.parameters()) + list(event.parameters())
optimizer = optim.Adam(params, lr=0.001)

for itr in range(30):
    optimizer.zero_grad()
    # Solve until the learned event function crosses zero (the first spike).
    event_t, state = odeint_event(func, v0, t0, event_fn=event, method='bosh3', atol=1e-6)
    # Convert the event time to an index into t (assumes t starts at 0 with a step of 0.1);
    # this integer index is not differentiable.
    end = int(event_t * 10 + 1)
    tt = t[:end]  # slice the time grid to solve the trajectory up to the first event
    pred_v = odeint(func, v0, tt)
    idx = pred_v.size(dim=0)
    loss1 = loss_fn(pred_v, v[:idx])  # voltage trajectory loss
    loss2 = loss_fn(event_t, st[0])  # st[0] is the first ground-truth spike time
    loss = loss1 + loss2
    loss.backward()
    optimizer.step()
    print(list(event.parameters())[0].grad)  # prints None every iteration
```
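A small, hypothetical helper for the check on the last line of the loop: instead of printing one gradient, it lists every trainable parameter whose `.grad` is still `None` after `backward()`, which shows at a glance whether the whole event network or only part of it is disconnected. It relies only on `nn.Module.named_parameters()` yielding `(name, tensor)` pairs and on each tensor exposing `.requires_grad` and `.grad`; the stand-in objects let the sketch run without PyTorch.

```python
from types import SimpleNamespace  # only used for the stand-in demo below


def params_without_grad(named_parameters):
    """Return names of trainable parameters whose .grad is still None.

    Intended to be called right after loss.backward(), e.g.
    params_without_grad(event.named_parameters()).
    """
    return [
        name
        for name, p in named_parameters
        if p.requires_grad and p.grad is None
    ]


# Stand-in objects so the sketch runs without PyTorch installed.
fake_params = [
    ("fc1.weight", SimpleNamespace(requires_grad=True, grad=None)),
    ("fc1.bias", SimpleNamespace(requires_grad=True, grad=[0.1])),
    ("frozen", SimpleNamespace(requires_grad=False, grad=None)),
]
print(params_without_grad(fake_params))  # ['fc1.weight']
```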