Complex Autograd failing


I am currently implementing complex GRU-cells as described in the paper “Complex Gated Recurrent Neural Networks”.
Working with the current nightly build (1.8.0dev20210201), and based on the list of operations that support complex gradients, I expected the code below to run without errors.
Instead, I receive the following error:

  File "...\model\spectral_rnn\", line 75, in train
  File "...\Python38\lib\site-packages\torch\", line 227, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "...\Python38\lib\site-packages\torch\autograd\", line 145, in backward
RuntimeError: Expected isFloatingType(grad.scalar_type()) || (input_is_complex == grad_is_complex) to be true, but got false.

I also reported this as a bug by opening an issue in the repository.
I would be grateful for any hints, e.g. about which operation might be causing the problem!


def to_complex_activation(activation):
    """Lift a real-valued activation to complex tensors.

    The activation is applied separately to the real and imaginary parts, and
    the two results are recombined into a complex tensor.

    Args:
        activation: callable taking a real tensor and returning a real tensor;
            expected to be elementwise (e.g. ``torch.sigmoid``) so both parts
            keep the same shape.

    Returns:
        A callable ``f`` such that
        ``f(x) == activation(x.real) + 1j * activation(x.imag)``.
    """
    def complex_activation(x):
        # torch.complex builds the result straight from the two real parts,
        # which avoids the unsqueeze/cat/view_as_complex round-trip.
        return torch.complex(activation(x.real), activation(x.imag))

    return complex_activation

class CGCell(nn.Module):
    """Complex-valued gated recurrent cell (one time step).

    Gate pre-activations are computed from the previous hidden state and the
    input with complex weights. They are squashed to real values in (0, 1) by
    ``fg``, and the candidate state goes through the split complex sigmoid
    ``fa``.

    Args:
        input_size: number of complex input features.
        hidden_size: number of complex hidden units.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # Gate parameters: reset and update gates are stacked along dim 0 and
        # split again in forward() via ``chunk(2, 1)``.
        self.wg = nn.Parameter(torch.randn(2 * hidden_size, hidden_size, dtype=torch.cfloat))
        self.vg = nn.Parameter(torch.randn(2 * hidden_size, input_size, dtype=torch.cfloat))
        self.bg = nn.Parameter(torch.randn(2 * hidden_size, dtype=torch.cfloat))

        # Candidate-state parameters.
        self.w = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=torch.cfloat))
        self.v = nn.Parameter(torch.randn(hidden_size, input_size, dtype=torch.cfloat))
        self.b = nn.Parameter(torch.randn(hidden_size, dtype=torch.cfloat))

        # Mixing weights of the real-valued gate activation.
        # TODO: placeholder values; consider making these configurable.
        self.alpha = 0.5
        self.beta = 0.5

    # Activations are methods rather than lambda attributes so the module
    # stays picklable (e.g. for torch.save(model)).
    def fg(self, x):
        """Map complex gate pre-activations to real values in (0, 1)."""
        return torch.sigmoid(self.alpha * x.real + self.beta * x.imag)

    def fa(self, x):
        """Apply sigmoid separately to the real and imaginary parts of *x*."""
        return torch.complex(torch.sigmoid(x.real), torch.sigmoid(x.imag))

    def _init_hidden(self, x):
        """Return a zero complex hidden state on *x*'s device, batch = x.shape[0]."""
        return torch.zeros(
            (x.shape[0], self.hidden_size), dtype=torch.cfloat, device=x.device
        )

    def forward(self, x, ht_=None):
        """Advance the cell by one time step.

        Args:
            x: complex input of shape (batch, input_size).
            ht_: previous hidden state of shape (batch, hidden_size);
                a zero state is used when None.

        Returns:
            The new complex hidden state of shape (batch, hidden_size).
        """
        if ht_ is None:
            ht_ = self._init_hidden(x)

        gates = ht_ @ self.wg.T + x @ self.vg.T + self.bg
        g_r, g_z = gates.chunk(2, 1)

        # Real-valued gates in (0, 1).
        g_r = self.fg(g_r)
        g_z = self.fg(g_z)

        z = (g_r * ht_) @ self.w.T + x @ self.v.T + self.b
        # g_z interpolates between the new candidate and the previous state.
        ht = g_z * self.fa(z) + (1 - g_z) * ht_

        return ht

The bug has been resolved; see this issue.