Complex recurrent layers produce NaN as grad

Hi all. I am trying to run the most basic single-layer RNN with complex inputs on the nightly build of PyTorch (1.10.0, CPU). The problem is that the gradient always evaluates to NaN. I’ve tried all recurrent layers (RNN, GRU, LSTM) with the same result. Here is the model:

import torch


class CRNN(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Single-layer Elman RNN (tanh nonlinearity) with complex weights.
        self.model = torch.nn.RNN(input_dim, output_dim,
                                  batch_first=True,
                                  dtype=torch.cfloat)
        self.loss = torch.nn.L1Loss()
        self.optimizer = torch.optim.Adam(self.parameters())

    def forward(self, x):
        # Add a leading batch dimension of size 1 (batch_first=True).
        x = torch.unsqueeze(x, 0)
        # nn.RNN returns (output, h_n); keep only the per-step outputs.
        return self.model(x)[0]

    def fit(self, x, y):
        self.optimizer.zero_grad()
        z = torch.squeeze(self(x), 0)  # drop the batch dimension again
        loss = self.loss(z, y)
        loss.backward()
        self.optimizer.step()
        return loss.item()

I used torch.autograd.set_detect_anomaly(True) and it gave the following results:

  • For RNN, the first NaN appears in AddmmBackward
  • For GRU and LSTM, the first NaN appears in L1LossBackward

Since I don’t know how to interpret this, I’m hoping someone in the community can help.
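For anyone hitting the same symptom, here is a minimal sketch (not from the thread) of combining anomaly detection with a direct check of each parameter's gradient. The helper nan_grad_params and the toy sizes are hypothetical. torch.autograd.detect_anomaly() makes backward raise an error as soon as a backward function produces NaN and prints the traceback of the forward op that created it; torch.isnan counts a complex element as NaN if either its real or imaginary part is NaN.

```python
import torch
from typing import List


def nan_grad_params(module: torch.nn.Module) -> List[str]:
    """Return the names of parameters whose .grad contains a NaN.

    torch.isnan treats a complex element as NaN if its real or imaginary
    part is NaN, so this works for complex parameters too.
    """
    return [
        name
        for name, p in module.named_parameters()
        if p.grad is not None and bool(torch.isnan(p.grad).any())
    ]


torch.manual_seed(0)
# Hypothetical toy sizes: 4 input features, 4 hidden units, 3 time steps.
rnn = torch.nn.RNN(4, 4, batch_first=True, dtype=torch.cfloat)
x = torch.rand(1, 3, 4, dtype=torch.cfloat)

# Inside this context, backward raises as soon as a backward function
# returns NaN, and reports the forward op that created it.
with torch.autograd.detect_anomaly():
    out, _ = rnn(x)
    loss = (out - x).abs().mean()  # same value as nn.L1Loss() on complex inputs
    loss.backward()

print(nan_grad_params(rnn))  # [] means no parameter received a NaN gradient
```

Checking the gradients directly tells you which weights are affected, while the anomaly traceback tells you which forward operation produced the first NaN.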

Are you able to reproduce this issue using random inputs? If so, could you post the random tensor initialization so that we could try to reproduce it, please?

Ok, I have tried running the same model with the following “dataset”:

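from itertools import repeat
import torch

class RandomDataset(torch.utils.data.IterableDataset):
    def __init__(self):
        super(RandomDataset, self).__init__()

    def __iter__(self):
        # itertools.repeat yields this one random tensor forever (identical samples).
        return repeat(torch.rand(100, dtype=torch.cfloat))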

Results are the same.

Thanks for the update!
I cannot reproduce the issue using:

from itertools import repeat

import torch


class CRNN(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.model = torch.nn.RNN(input_dim, output_dim,
                                  batch_first=True,
                                  dtype=torch.cfloat)
        self.loss = torch.nn.L1Loss()
        self.optimizer = torch.optim.Adam(self.parameters())

    def forward(self, x):
        x = torch.unsqueeze(x, 0)
        return self.model(x)[0]

    def fit(self, x, y):
        self.optimizer.zero_grad()
        z = torch.squeeze(self(x), 0)
        loss = self.loss(z, y)
        loss.backward()
        self.optimizer.step()
        return loss.item()

class RandomDataset(torch.utils.data.IterableDataset):
    def __init__(self):
        super(RandomDataset, self).__init__()

    def __iter__(self):
        return repeat(torch.rand(100, dtype=torch.cfloat))
    
model = CRNN(100, 100)
dataset = RandomDataset()
loader = torch.utils.data.DataLoader(dataset)

for i, data in enumerate(loader):
    loss = model.fit(data, data)
    print(i, loss)
    if i > 10000:
        break

Can you please try with GRU/LSTM too?

OK, it seems like I found something interesting.
In your code, try to replace

loader = torch.utils.data.DataLoader(dataset)

with

loader = torch.utils.data.DataLoader(dataset, batch_size=100)

Whether or not the error appears depends on the batch_size. For me, it breaks when batch_size is ≥30.
Looks like exploding gradients to me.
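A minimal sketch that automates this sweep, assuming the same setup as the repro above: one fixed random complex sequence, a single-layer complex nn.RNN, L1Loss and Adam. It reports the first training step whose loss is NaN or inf. The function name first_nonfinite_step, the step count, and the lengths tried are hypothetical choices, and where (or whether) NaNs appear will depend on the seed and PyTorch version. Because of the unsqueeze in forward, the DataLoader batch_size acts as the sequence length, so the sweep varies seq_len directly.

```python
import math

import torch


def first_nonfinite_step(seq_len: int, dim: int = 100, steps: int = 200, seed: int = 0):
    """Train a complex RNN on one fixed random sequence (like the repeat()-based
    dataset) and return the first step whose loss is NaN/inf, or None."""
    torch.manual_seed(seed)
    rnn = torch.nn.RNN(dim, dim, batch_first=True, dtype=torch.cfloat)
    opt = torch.optim.Adam(rnn.parameters())
    loss_fn = torch.nn.L1Loss()
    # Shape (1, seq_len, dim): batch of 1, seq_len time steps, which is what
    # unsqueeze(x, 0) produces from a DataLoader batch of size seq_len.
    x = torch.rand(1, seq_len, dim, dtype=torch.cfloat)
    for step in range(steps):
        opt.zero_grad()
        out, _ = rnn(x)
        loss = loss_fn(out, x)
        if not math.isfinite(loss.item()):
            return step
        loss.backward()
        opt.step()
    return None


for seq_len in (10, 30, 60):
    print(seq_len, first_nonfinite_step(seq_len))
```

A NaN gradient is only caught one step later, when the corrupted weights produce a non-finite loss; combining this with a gradient check would catch it earlier.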

Yes, this could be the case, but note that your code snippet uses the DataLoader’s batch_size as the sequence length: x = torch.unsqueeze(x, 0) creates a hard-coded batch size of 1, so the batch_size specified in the DataLoader ends up in dim1, which (with batch_first=True) is the sequence dimension.
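To make that shape bookkeeping concrete, here is a small sketch reusing the thread's RandomDataset and the same unsqueeze as CRNN.forward; batch_size=30 is just an example value. A DataLoader batch of 30 samples of length 100 becomes an RNN input of shape (1, 30, 100), i.e. a batch of 1 and a sequence of length 30.

```python
from itertools import repeat

import torch


class RandomDataset(torch.utils.data.IterableDataset):
    def __iter__(self):
        # One fixed complex vector of length 100, yielded forever.
        return repeat(torch.rand(100, dtype=torch.cfloat))


loader = torch.utils.data.DataLoader(RandomDataset(), batch_size=30)
data = next(iter(loader))        # (30, 100): 30 stacked samples
x = torch.unsqueeze(data, 0)     # (1, 30, 100): what CRNN.forward feeds the RNN

rnn = torch.nn.RNN(100, 100, batch_first=True, dtype=torch.cfloat)
out, h_n = rnn(x)
# batch_first=True: dim 0 is the batch (1), dim 1 is time (30).
# h_n is always (num_layers, batch, hidden), regardless of batch_first.
print(data.shape, x.shape, out.shape, h_n.shape)
```

If the intent is a real batch of 30 independent length-1 sequences, unsqueezing along dim 1 instead (torch.unsqueeze(x, 1)) would give shape (30, 1, 100).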

Yes, this is exactly why I assumed it to be an issue of exploding gradients: the error appears when the sequence length reaches a critical threshold.
Is this expected behavior or a bug?

I would claim it’s expected, as gradient clipping is usually applied when training on long sequences.
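A minimal sketch of what gradient clipping in the thread's fit method could look like, using torch.nn.utils.clip_grad_norm_ (L2 norm by default) between backward() and step(). The class name ClippedCRNN, the attribute last_grad_norm, and max_norm=1.0 are hypothetical choices; this illustrates the technique and is not a guaranteed fix for the setup in this thread.

```python
import torch


class ClippedCRNN(torch.nn.Module):
    """The thread's CRNN with gradient-norm clipping added to fit().

    max_norm is a hypothetical default; tune it for your data.
    """

    def __init__(self, input_dim: int, output_dim: int, max_norm: float = 1.0):
        super().__init__()
        self.model = torch.nn.RNN(input_dim, output_dim,
                                  batch_first=True, dtype=torch.cfloat)
        self.loss = torch.nn.L1Loss()
        self.optimizer = torch.optim.Adam(self.parameters())
        self.max_norm = max_norm
        self.last_grad_norm = None

    def forward(self, x):
        return self.model(torch.unsqueeze(x, 0))[0]

    def fit(self, x, y):
        self.optimizer.zero_grad()
        z = torch.squeeze(self(x), 0)
        loss = self.loss(z, y)
        loss.backward()
        # Rescale all gradients so their combined L2 norm is at most max_norm;
        # the return value is the total norm measured before clipping.
        self.last_grad_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), self.max_norm)
        self.optimizer.step()
        return loss.item()


torch.manual_seed(0)
model = ClippedCRNN(100, 100)
data = torch.rand(30, 100, dtype=torch.cfloat)  # sequence length 30
for i in range(5):
    print(i, model.fit(data, data), float(model.last_grad_norm))
```

Logging the pre-clipping norm is a cheap way to see whether gradients grow steadily before the NaNs appear, which would support the exploding-gradient explanation.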

The same behavior does not occur with real numbers for the same model, though. Is there a theoretical reason as to why complex numbers are more prone to exploding gradients?
Also, I’ve tested the same setup with LSTM and GRU: there the cutoff is a sequence length of around 100, so they delay the onset of exploding gradients compared to RNN but do not prevent it. The same goes for L2-norm gradient clipping.
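On the theoretical question, one mathematical difference that may be relevant (not confirmed as the cause anywhere in this thread): nn.RNN uses tanh by default, and for real inputs |tanh(x)| is bounded by 1. The complex tanh instead has poles at z = i(π/2 + kπ), so both its magnitude and its derivative 1 − tanh²(z) can become arbitrarily large. The sigmoid used in GRU/LSTM gates likewise has poles in the complex plane, at z = iπ(2k + 1). The snippet below only illustrates the tanh case numerically, using the identity tanh(iy) = i·tan(y).

```python
import math

import torch

# Real tanh is bounded: |tanh(x)| <= 1 for every real x.
real_x = torch.linspace(-50.0, 50.0, steps=1001, dtype=torch.float64)
print(torch.tanh(real_x).abs().max())

# Complex tanh on the imaginary axis: tanh(i*y) = i*tan(y), which blows up
# as y approaches pi/2 (a pole of tanh).
for eps in (1e-1, 1e-3, 1e-6):
    z = torch.tensor(complex(0.0, math.pi / 2 - eps), dtype=torch.complex128)
    t = torch.tanh(z)
    # d/dz tanh(z) = 1 - tanh(z)**2, so the local gain grows roughly like 1/eps**2.
    print(eps, t.abs().item(), (1 - t * t).abs().item())
```

Whether hidden-state pre-activations actually drift toward these poles during training would need to be checked on the model itself, for example by logging the imaginary parts of the RNN's pre-activations.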