Differentiate through a complex differential equation

I’m currently working on a problem involving a differential equation with a complex-valued term, so the solution of the equation will also be complex-valued. The equation looks like this:

y''(t) + (A(t) + j*B(t))**2 * y(t) = 0

… where j is the imaginary unit. I want my neural network to predict the coefficient B so that the solution y satisfies some specific conditions. To do that, I would solve the differential equation with the predicted B, compute a loss on y, and then backpropagate through the ODE solver to my model. I think this should work in principle, but I don’t know how to deal with the complex values, as PyTorch doesn’t have complex numbers implemented yet. Has anybody worked on similar problems and can offer some advice?
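Even without complex tensors, the equation can be handled by splitting y = u + j*v into real and imaginary parts. Since (A + jB)^2 = (A^2 - B^2) + j*2AB, the single complex equation becomes two coupled real equations: u'' = -(A^2 - B^2) u + 2AB v and v'' = -(A^2 - B^2) v - 2AB u. The following is a minimal sketch under several assumptions: it uses a hand-written fixed-step RK4 integrator (not any particular ODE library), A(t) = 1, a hypothetical two-parameter B(t) = theta0 + theta1*t standing in for the network output, and a made-up target y(1) = 0.5 + 0.5j.

```python
import torch


def rhs(t, s, A, B):
    """Right-hand side of the real first-order system.

    s = [u, v, u', v'] with y = u + j*v. Expanding
    (A + jB)^2 = (A^2 - B^2) + j*2AB in y'' = -(A + jB)^2 y gives
        u'' = -(A^2 - B^2) u + 2AB v
        v'' = -(A^2 - B^2) v - 2AB u
    """
    u, v, du, dv = s.unbind()
    a, b = A(t), B(t)
    p = a**2 - b**2  # real part of (A + jB)^2
    q = 2 * a * b    # imaginary part of (A + jB)^2
    return torch.stack([du, dv, -p * u + q * v, -p * v - q * u])


def solve_rk4(s0, t, A, B):
    """Fixed-step classic RK4 built from plain torch ops, so autograd
    can backpropagate through every step."""
    s, out = s0, [s0]
    for i in range(len(t) - 1):
        h = t[i + 1] - t[i]
        k1 = rhs(t[i], s, A, B)
        k2 = rhs(t[i] + h / 2, s + h / 2 * k1, A, B)
        k3 = rhs(t[i] + h / 2, s + h / 2 * k2, A, B)
        k4 = rhs(t[i] + h, s + h * k3, A, B)
        s = s + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        out.append(s)
    return torch.stack(out)  # shape (len(t), 4)


# Hypothetical parameters of B; in practice B would come from the network.
theta = torch.tensor([0.3, 0.0], dtype=torch.float64, requires_grad=True)


def A(t):
    return torch.ones_like(t)  # assumed constant A(t) = 1


def B(t):
    return theta[0] + theta[1] * t  # hypothetical B(t) = theta0 + theta1*t


t = torch.linspace(0.0, 1.0, 101, dtype=torch.float64)
s0 = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float64)  # y(0)=1, y'(0)=0
traj = solve_rk4(s0, t, A, B)

# Hypothetical condition y(1) = 0.5 + 0.5j, written as a real-valued loss.
loss = (traj[-1, 0] - 0.5) ** 2 + (traj[-1, 1] - 0.5) ** 2
loss.backward()
print(theta.grad)  # d(loss)/d(theta), obtained through the whole solve
```

Because every RK4 step is an ordinary tensor operation, loss.backward() differentiates through the entire solve; the price is that the autograd graph, and therefore memory use, grows with the number of steps.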

Hi @fewagner, PyTorch recently added complex number support, so you can create tensors of complex dtype, e.g.
torch.randn(4, 4, dtype=torch.cfloat), and perform operations on them. You can also compute complex derivatives, i.e. run backward() on real-valued loss functions of complex-valued parameters. For more on complex autograd, check out: https://pytorch.org/docs/stable/notes/autograd.html?highlight=complex%20autograd
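With complex dtypes, y can stay complex throughout the solve. A minimal sketch, assuming a hand-written fixed-step RK4 integrator (not any particular ODE library), A(t) = 1, a hypothetical two-parameter B(t) = theta0 + theta1*t in place of the network output, and a made-up target y(1) = 0.5 + 0.5j. It uses torch.complex to form A + jB from the two real tensors and torch.complex128 (torch.cdouble) for precision; torch.cfloat should work the same way in single precision.

```python
import torch


def rhs(t, s, A, B):
    """s = [y, y'] as a complex tensor; returns [y', y''] with
    y'' = -(A + jB)^2 * y. A(t) and B(t) return real tensors."""
    y, dy = s.unbind()
    c = torch.complex(A(t), B(t))  # A + jB, complex128 for float64 inputs
    return torch.stack([dy, -c**2 * y])


def solve_rk4(s0, t, A, B):
    """Fixed-step classic RK4 built from plain torch ops, so autograd
    can backpropagate through every step."""
    s, out = s0, [s0]
    for i in range(len(t) - 1):
        h = t[i + 1] - t[i]
        k1 = rhs(t[i], s, A, B)
        k2 = rhs(t[i] + h / 2, s + h / 2 * k1, A, B)
        k3 = rhs(t[i] + h / 2, s + h / 2 * k2, A, B)
        k4 = rhs(t[i] + h, s + h * k3, A, B)
        s = s + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        out.append(s)
    return torch.stack(out)  # shape (len(t), 2), complex


# Hypothetical real parameters of B; in practice B would come from the network.
theta = torch.tensor([0.3, 0.0], dtype=torch.float64, requires_grad=True)


def A(t):
    return torch.ones_like(t)  # assumed constant A(t) = 1


def B(t):
    return theta[0] + theta[1] * t  # hypothetical B(t) = theta0 + theta1*t


t = torch.linspace(0.0, 1.0, 101, dtype=torch.float64)
s0 = torch.tensor([1.0, 0.0], dtype=torch.complex128)  # y(0) = 1, y'(0) = 0
traj = solve_rk4(s0, t, A, B)

target = torch.tensor(0.5 + 0.5j, dtype=torch.complex128)  # hypothetical
loss = (traj[-1, 0] - target).abs() ** 2  # real-valued scalar loss
loss.backward()
print(theta.grad)  # real gradient, although intermediate values are complex
```

The loss is made real-valued with .abs() ** 2 because backward() without an explicit gradient argument needs a real scalar output. Since theta (like a network's weights) is real, theta.grad is an ordinary real gradient.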