The thing is that I am trying to implement a Neural ODE (NODE) in Pytorch. I’m following this example: https://github.com/DiffEqML/torchdyn/blob/master/tutorials/02_classification.ipynb
If you know a simple implementantion of NODE, please share it with me!
Regarding the error:
f = nn.Sequential(
nn.Conv1d(3,64,kernel_size=2),
nn.ReLU(),
nn.MaxPool1d(2),
nn.Flatten(start_dim=1),
nn.Linear(576,50),
nn.ReLU(),
nn.Linear(50, ph)
)
model = NeuralDE(f,
solver='rk4',
sensitivity='autograd',
s_span=torch.linspace(0, 1, 10))
learningRate = 0.01
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learningRate)
for epoch in range(180): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 200 == 199:
#if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
#(epoch + 1, i + 1, running_loss / 2000))
(epoch + 1, i + 1, running_loss / 200))
running_loss = 0.0
print('Finished Training')
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-27-15bbd244efe3> in <module>()
10
11 # forward + backward + optimize
---> 12 outputs = model(inputs)
13 loss = criterion(outputs, labels)
14 loss.backward()
6 frames
/usr/local/lib/python3.6/dist-packages/torchdiffeq/_impl/rk_common.py in rk4_alt_step_func(func, t, dt, y, k1)
98 if k1 is None:
99 k1 = func(t, y)
--> 100 k2 = func(t + dt * _one_third, y + dt * k1 * _one_third)
101 k3 = func(t + dt * _two_thirds, y + dt * (k2 - k1 * _one_third))
102 k4 = func(t + dt, y + dt * (k1 - k2 + k3))
RuntimeError: The size of tensor a (3) must match the size of tensor b (200) at non-singleton dimension 1
If you know any simple implementation of NODE, please share it. Thank you!