Updating complex parameters and defining a real-valued MSE loss

I have the following fully connected architecture and parameters for a network that I train on time-series data:

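import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, input_dim, output_dim, hl_sizes):
        super(Net, self).__init__()
        current_dim = input_dim
        self.linears = nn.ModuleList()
        for hl_dim in hl_sizes:
            self.linears.append(nn.Linear(current_dim, hl_dim))
            current_dim = hl_dim
        # L and V build the matrix A = V diag(exp(L*dt)) V^+ applied to the features;
        # output_dim must equal the feature size from forward(): 1 + input_dim + hl_sizes[-1]
        self.L = nn.Parameter(torch.rand(output_dim, dtype=torch.cfloat))
        self.V = nn.Parameter(torch.rand(output_dim, output_dim, dtype=torch.cfloat))

    def forward(self, x):
        input_vecs = x
        for layer in self.linears:
            x = F.relu(layer(x))
        # Features: a constant 1, the raw input, and the last hidden layer
        ones = torch.ones(x.shape[0], 1, dtype=x.dtype, device=x.device)
        x = torch.cat((ones, input_vecs, x), dim=1)
        return x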

net = Net(...)  # constructor arguments omitted
loss_func = nn.MSELoss()

eL = torch.diag_embed(torch.exp(net.L*dt)) 
A = net.V @ eL @ torch.pinverse(net.V) 

X = net(time_series_previous) @ A
Y = net(time_series_next)

loss = loss_func(X,Y)


which works just fine when the parameters L and V are real-valued (e.g. torch.float32). However, these parameters need to be complex-valued (torch.cfloat), and making this change raises several problems:

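  1. torch.pinverse does not accept complex tensors (at least in my PyTorch version). Workaround, computing the pseudo-inverse from the SVD:
U, s, W = torch.svd(net.V)  # net.V = U diag(s) W^H
pinvV = W @ torch.diag_embed(1 / s).to(W.dtype) @ U.conj().t()  # W diag(1/s) U^H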
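For reference, the SVD route (note that it needs the conjugate transpose of U, not just its conjugate) can be sketched as below. This assumes a PyTorch release in which torch.linalg.svd and `@` accept complex tensors; recent releases also provide torch.linalg.pinv, which handles complex input directly. The names complex_pinv and rcond are hypothetical, chosen for illustration.

```python
import torch


def complex_pinv(V: torch.Tensor, rcond: float = 1e-6) -> torch.Tensor:
    """Pseudo-inverse of a single (possibly complex) matrix via SVD.

    With V = U diag(s) W^H, the pseudo-inverse is W diag(1/s) U^H, where ^H is
    the conjugate transpose. Singular values below rcond * max(s) are treated as zero.
    """
    U, s, Vh = torch.linalg.svd(V, full_matrices=False)  # Vh = W^H; s is real
    cutoff = rcond * s.max()
    s_inv = torch.where(s > cutoff, 1.0 / s, torch.zeros_like(s))
    W = Vh.conj().transpose(-2, -1)
    # W * s_inv scales column j of W by s_inv[j], i.e. W @ diag(s_inv).
    return (W * s_inv.to(V.dtype)) @ U.conj().transpose(-2, -1)


torch.manual_seed(0)
V = torch.randn(4, 4, dtype=torch.cdouble)  # double precision for a tight check
pinvV = complex_pinv(V)
print(torch.allclose(pinvV @ V, torch.eye(4, dtype=torch.cdouble)))
```

Scaling the columns by broadcasting avoids building the diagonal matrix, and it also sidesteps multiplying a real diagonal matrix by complex ones.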

However, complex matrix multiplication is not supported in the PyTorch version I'm using, so I can't compute pinvV directly. So I wrote my own complex matrix multiplication, which seems to work. That leaves two more problems:

  2. The matrix A will generally be complex, so X will be complex too. Why does nn.MSELoss() (i.e. loss_func(X,Y)) return a complex scalar? The norm of a complex number is real-valued! I don't have a workaround for this issue yet.

  3. Lastly, even when X does end up real-valued, so that the loss is real-valued, PyTorch raises an error along the lines of grad_is_complex == True, expected False, specifically when loss.backward() is called.
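On the loss: nn.MSELoss squares the elementwise difference, (x - y)**2, before averaging. For a complex number z, z**2 is generally complex and is not the same as |z|**2 = z * conj(z), which is real. A real-valued alternative is the mean of |X - Y|**2. The sketch below assumes a PyTorch release with complex autograd support; complex_mse and the random tensors are hypothetical stand-ins. Because the loss is a real scalar, a plain loss.backward() runs, and the gradient with respect to the complex parameter is itself complex.

```python
import torch


def complex_mse(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Mean squared error that is real-valued for real or complex inputs: mean(|x - y|^2)."""
    diff = x - y  # elementwise ops promote: complex - real gives complex
    # abs() of a complex tensor is its real modulus, so |diff|^2 is real and non-negative.
    return diff.abs().pow(2).mean()


torch.manual_seed(0)
w = torch.nn.Parameter(torch.randn(3, 3, dtype=torch.cfloat))  # complex learnable matrix
features = torch.randn(5, 3)  # real-valued features, standing in for net(...) output
target = torch.randn(5, 3)    # real-valued target

pred = features.to(w.dtype) @ w   # matmul does not promote dtypes, so cast real -> complex
loss = complex_mse(pred, target)  # real scalar, so plain loss.backward() is allowed
loss.backward()
print(loss.dtype, w.grad.dtype)   # torch.float32 torch.complex64
```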

I would appreciate guidance on how to properly implement a network with learnable complex parameters such as the one above. Thanks!
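One way the pieces could fit together, as a sketch rather than a verified recipe: the linear layers stay real, L and V are complex nn.Parameters, the real features are cast to complex before multiplying by A, and training uses the real-valued loss mean(|X - Y|^2). It assumes a recent PyTorch release with complex autograd and torch.linalg, and it swaps the pseudo-inverse for torch.linalg.inv on the assumption that the square matrix V stays invertible. The propagator method, complex_mse, the layer sizes, dt and the random batches are hypothetical placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, input_dim, hl_sizes):
        super().__init__()
        dims = [input_dim, *hl_sizes]
        self.linears = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )
        # Feature size produced by forward(): constant 1 + raw input + last hidden layer.
        feat_dim = 1 + input_dim + hl_sizes[-1]
        self.L = nn.Parameter(torch.randn(feat_dim, dtype=torch.cfloat))
        self.V = nn.Parameter(torch.randn(feat_dim, feat_dim, dtype=torch.cfloat))

    def forward(self, x):
        h = x
        for layer in self.linears:
            h = F.relu(layer(h))
        ones = torch.ones(x.shape[0], 1, dtype=x.dtype, device=x.device)
        return torch.cat((ones, x, h), dim=1)

    def propagator(self, dt):
        # A = V diag(exp(L*dt)) V^{-1}; (V * e) scales column j of V by e[j], i.e. V @ diag(e).
        e = torch.exp(self.L * dt)
        return (self.V * e) @ torch.linalg.inv(self.V)


def complex_mse(x, y):
    # Real-valued loss: mean of |x - y|^2.
    return (x - y).abs().pow(2).mean()


torch.manual_seed(0)
net = Net(input_dim=4, hl_sizes=[16, 8])
opt = torch.optim.SGD(net.parameters(), lr=1e-3)
dt = 0.1
prev = torch.randn(32, 4)  # placeholder batch, not real time-series data
nxt = torch.randn(32, 4)

for step in range(5):
    opt.zero_grad()
    A = net.propagator(dt)
    X = net(prev).to(A.dtype) @ A  # matmul does not promote, so cast features to complex
    Y = net(nxt)                   # stays real; X - Y promotes to complex
    loss = complex_mse(X, Y)       # real scalar
    loss.backward()
    opt.step()
    print(step, loss.item())
```

Plain SGD is used here because its update is elementwise and applies unchanged to complex tensors; the optimizer simply sees a mix of real (Linear) and complex (L, V) parameters.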