I have the following fully connected architecture and parameters for a network that I train on time-series data:
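import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, input_dim, output_dim, hl_sizes):
        super(Net, self).__init__()
        current_dim = input_dim
        self.linears = nn.ModuleList()
        for hl_dim in hl_sizes:
            self.linears.append(nn.Linear(current_dim, hl_dim))
            current_dim = hl_dim
        # Learnable complex parameters: L enters as exp(L*dt) on a diagonal, V as a square matrix
        self.L = nn.Parameter(torch.rand(output_dim, dtype=torch.cfloat))
        self.V = nn.Parameter(torch.rand(output_dim, output_dim, dtype=torch.cfloat))

    def forward(self, x):
        input_vecs = x
        for layer in self.linears:
            x = F.relu(layer(x))
        # Concatenate a column of ones, the raw inputs, and the last hidden activations.
        # The result has width 1 + input_dim + hl_sizes[-1], which has to equal output_dim
        # for the product with A below.
        ones = torch.ones(x.shape[0], 1, dtype=x.dtype, device=x.device)
        x = torch.cat((ones, input_vecs, x), dim=1)
        return x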
net = Net(
loss_func = nn.MSELoss()
eL = torch.diag_embed(torch.exp(net.L*dt))
A = net.V @ eL @ torch.pinverse(net.V)
X = net(time_series_previous) @ A
Y = net(time_series_next)
loss = loss_func(X,Y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
This works fine when the parameters L and V are real-valued (e.g. torch.float32). However, they need to be complex-valued (torch.cfloat), and making that change causes several problems:
- torch.pinverse is not defined for complex data types in the PyTorch version I am using. Workaround:
U, s, W = torch.svd(net.V)  # net.V = U @ diag(s) @ W^H
pinvV = W @ torch.diag_embed(1/s).to(W.dtype) @ U.conj().transpose(-2, -1)  # conjugate transpose of U, not just U.conj()
However, complex matrix multiplication is not defined in the PyTorch version I am using either, so pinvV cannot be computed this way directly. I wrote my own complex matrix multiplication, and it seems to be working. The next problems:
- The parameter A will most likely be complex, so X will be complex as well. Why does nn.MSELoss(), i.e. loss_func(X, Y), return a complex scalar? The norm of a complex number is real-valued! I don’t currently have a workaround for this issue.
- Lastly, even if I do end up with a real-valued X, and therefore a real-valued loss, PyTorch raises an error along the lines of "grad_is_complex == True, expected False" (I am paraphrasing the message). The error occurs specifically when loss.backward() is called.
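For reference, here is a minimal sketch of the three points above, assuming a recent PyTorch release in which complex matrix multiplication, torch.linalg.pinv, and autograd for complex tensors are available (I have not pinned down the minimum version). The sizes n and dt, the random feats_prev / feats_next tensors standing in for the outputs of net(...), and the helper names complex_pinv_svd and complex_mse are hypothetical and not part of the code above. The differentiable path uses torch.linalg.pinv; the SVD-based helper is only there to illustrate the pseudo-inverse formula with the conjugate transpose.

```python
import torch

torch.manual_seed(0)

def complex_pinv_svd(V):
    """Pseudo-inverse via SVD: V = U diag(s) Wh  =>  pinv(V) = Wh^H diag(1/s) U^H."""
    U, s, Wh = torch.linalg.svd(V, full_matrices=False)
    s_inv = (1.0 / s).to(V.dtype)  # singular values are real; cast for the complex product
    return Wh.conj().transpose(-2, -1) @ torch.diag_embed(s_inv) @ U.conj().transpose(-2, -1)

def complex_mse(X, Y):
    """Mean of |X - Y|^2, which is a real scalar even when X or Y is complex."""
    return (X - Y).abs().pow(2).mean()

# Hypothetical sizes and data (assumptions for illustration only)
n, dt = 4, 0.1
L = torch.nn.Parameter(torch.randn(n, dtype=torch.cfloat))
V = torch.nn.Parameter(torch.randn(n, n, dtype=torch.cfloat))
feats_prev = torch.randn(8, n)  # stand-in for net(time_series_previous)
feats_next = torch.randn(8, n)  # stand-in for net(time_series_next)

# Sanity check of the SVD formula against the built-in pseudo-inverse (no gradients needed)
with torch.no_grad():
    pinv_err = (complex_pinv_svd(V) - torch.linalg.pinv(V)).abs().max()

# Same construction as in the question, with complex L and V
eL = torch.diag_embed(torch.exp(L * dt))
A = V @ eL @ torch.linalg.pinv(V)

# matmul needs matching dtypes, so the real features are cast to complex first
X = feats_prev.to(A.dtype) @ A
loss = complex_mse(X, feats_next)  # real-valued scalar

loss.backward()  # the gradients of L and V are complex tensors
print(loss.dtype, L.grad.dtype, V.grad.dtype, float(pinv_err))
```

The key design choice in this sketch is that the loss is reduced to a real scalar (the mean squared modulus of the difference) before backward() is called, while the parameters themselves stay complex.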
I would appreciate guidance on how to properly implement a network with learnable complex parameters such as the one above. Thanks!