Custom LSTM returns nan

Hi, I implemented my own custom LSTMCell based on [pytorch/benchmarks/fastrnns/custom_lstms.py at main · pytorch/pytorch · GitHub], but during back-propagation I get NaN values (after two or three iterations). To be more specific, my net consists of a CNN (AlexNet) + CustomRNN + log_softmax and is trained with CTC loss. My custom LSTM is an implementation of the Differential RNN https://arxiv.org/abs/1504.06678. Below are some snippets of my code.
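For reference, these are the updates computed by the LSTMCell in the full code posted further down (order N = 1). Here s is the cell state `cx`, Δs is the stored derivative `dc`, and each b is the sum of all bias vectors added to that gate. This is read off the code rather than copied from the paper, so the paper's notation and equation numbering may differ. The W_s and W_d terms are what distinguish it from a standard LSTM.

```latex
\begin{aligned}
i_t &= \sigma\!\left(W_{xi} x_t + W_{hi} h_{t-1} + W_{si} s_{t-1} + W_{di}\,\Delta s_{t-1} + b_i\right) \\
f_t &= \sigma\!\left(W_{xf} x_t + W_{hf} h_{t-1} + W_{sf} s_{t-1} + W_{df}\,\Delta s_{t-1} + b_f\right) \\
g_t &= \tanh\!\left(W_{xg} x_t + W_{hg} h_{t-1} + b_g\right) \\
s_t &= f_t \odot s_{t-1} + i_t \odot g_t \\
o_t &= \sigma\!\left(W_{xo} x_t + W_{ho} h_{t-1} + W_{so} s_t + W_{do}\,\Delta s_t + b_o\right) \\
h_t &= o_t \odot \tanh(s_t), \qquad \Delta s_t = s_t - s_{t-1}
\end{aligned}
```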
LSTMCell:

Model forward:
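def forward(self, x):
    LSTMState = namedtuple('LSTMState', ['hx', 'cx', 'dc'])

    batch_size, timesteps, C, H, W = x.size()
    c_in = x.view(batch_size * timesteps, C, H, W)
    c_out = self.cnn(c_in)                      # (batch_size * timesteps, 4096), batch-major
    # regroup per sample, then swap to the time-major layout the RNN expects
    c_out = c_out.view(batch_size, timesteps, 4096).transpose(0, 1)

    # zero initial hidden state, cell state and cell-state derivative
    h1 = torch.zeros(batch_size, self.hidden_size, device=x.device)
    h2 = torch.zeros(batch_size, self.hidden_size, device=x.device)
    h3 = torch.zeros(batch_size, self.hidden_size, device=x.device)

    # one (forward, backward) pair of states per layer
    states = [[LSTMState(h1, h2, h3)
               for _ in range(2)]
              for _ in range(self.num_layers)]
    r_out, out_state = self.rnn(c_out, states)
    custom_state = double_flatten_states(out_state)  # helper not shown; result unused here
    r_out2 = self.last_linear(r_out)
    return r_out2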
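A separate point about this forward, probably unrelated to the NaNs: if `x` is contiguous, `x.view(batch_size * timesteps, ...)` flattens in batch-major order, so reshaping the CNN output directly with `c_out.view(-1, batch_size, 4096)` does not give a time-major `(timesteps, batch_size, 4096)` tensor; it interleaves samples and time steps. The toy tensor below (arbitrary small sizes, one feature) shows the difference between a direct `view` and `view` followed by `transpose`.

```python
import torch

# Arbitrary small sizes so the ordering is easy to read.
batch_size, timesteps, features = 2, 3, 1

# Stand-in for the CNN output: row k comes from sample k // timesteps
# at time step k % timesteps (batch-major order).
c_out = torch.arange(batch_size * timesteps, dtype=torch.float32).view(-1, features)

# Direct view to (T, B, F): reinterprets memory, mixing samples and time steps.
wrong = c_out.view(-1, batch_size, features)

# Restore (B, T, F) first, then swap the first two dims to get (T, B, F).
right = c_out.view(batch_size, timesteps, features).transpose(0, 1)

print(wrong[:, :, 0])  # rows are "time steps": [[0, 1], [2, 3], [4, 5]]
print(right[:, :, 0])  # rows are time steps:   [[0, 3], [1, 4], [2, 5]]
```

In `right`, each row holds one time step from both samples (sample 0 owns values 0 to 2, sample 1 owns 3 to 5), which is what the scripted LSTM layers assume when they call `input.unbind(0)`.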

Thanks in advance


I guess you should also include some of your training code to help troubleshoot. The code you’ve provided here looks ok.

Given that it happens after a few epochs, I guess the gradient is either vanishing or exploding. Either one could be caused by a learning rate issue. For exploding gradients you could try gradient clipping.
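A minimal sketch of gradient clipping with `torch.nn.utils.clip_grad_norm_`, called between `loss.backward()` and `optimizer.step()`. The linear model, the random data and `max_norm=1.0` are placeholders for illustration, not values tuned for this model. Clipping only rescales gradients; it cannot repair gradients that are already NaN or inf.

```python
import torch
import torch.nn as nn

# Tiny stand-in model and data; in the thread this would be the CNN + custom LSTM.
torch.manual_seed(0)
model = nn.Linear(8, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(16, 8)
target = torch.randn(16, 4)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x) * 1e3, target)  # scaled up to force large gradients
loss.backward()

# Rescale all gradients in place so their combined L2 norm is at most max_norm.
# The return value is the total norm *before* clipping, handy for logging spikes.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

clipped_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
optimizer.step()
print(f"norm before: {total_norm.item():.1f}, after: {clipped_norm.item():.3f}")
```

Logging `total_norm` every batch also shows whether the norm grows steadily before the NaN appears. In recent PyTorch versions `clip_grad_norm_` additionally accepts `error_if_nonfinite=True`, which raises instead of silently working with non-finite gradients.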

Hi Jack, and thanks for your reply. First of all, what I meant was that I get NaN after two or three batches, not epochs (reducing the learning rate wouldn't work). The only thing I changed in the original code is that I added one state (dc, the derivative of the cell state) and, of course, the corresponding terms in the gate computations, as you can see above. Consequently, I can't understand why it returns NaN (I can run the original code with my model and training code without a problem).

Training code:

criterion = nn.CTCLoss(blank=0, reduction='mean')
train_loss = 0.0
# with autograd.detect_anomaly():
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)                        # (T, N, num_classes), time-major
    # CTCLoss expects one length per sequence in the batch, i.e. shape (N,)
    input_len = torch.full((output.size(1),), output.size(0), dtype=torch.long)
    target_len = torch.full((target.size(0),), target.size(1), dtype=torch.long)
    log_probs = nn.functional.log_softmax(output, dim=2)
    loss = criterion(log_probs, target, input_len, target_len)
    train_loss += loss.item()
    loss.backward()
    optimizer.step()
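One way to keep a single bad batch from corrupting the weights while debugging is sketched below, under stated assumptions: `zero_infinity=True` is a real `nn.CTCLoss` option that zeroes infinite losses (inputs too short to align with their targets) and their gradients, and the step is skipped whenever the loss or any gradient is non-finite. The linear "model", the sizes and the random data are hypothetical stand-ins. These are workarounds, not a fix for whatever makes the custom cell produce NaN.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, N, C, S = 12, 2, 5, 4           # time steps, batch, classes (incl. blank 0), target length
model = nn.Linear(3, C)            # hypothetical stand-in for CNN + custom LSTM
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# zero_infinity=True replaces infinite losses and their gradients with zeros.
criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

def train_step(features, target):
    optimizer.zero_grad()
    output = model(features)                                   # (T, N, C)
    log_probs = nn.functional.log_softmax(output, dim=2)
    # one length per batch element, shape (N,)
    input_len = torch.full((output.size(1),), output.size(0), dtype=torch.long)
    target_len = torch.full((target.size(0),), target.size(1), dtype=torch.long)
    loss = criterion(log_probs, target, input_len, target_len)
    loss.backward()
    grads_finite = all(torch.isfinite(p.grad).all()
                       for p in model.parameters() if p.grad is not None)
    if not (torch.isfinite(loss) and grads_finite):
        return None                 # skip the update instead of writing NaN into the weights
    optimizer.step()
    return loss.item()

features = torch.randn(T, N, 3)
target = torch.randint(1, C, (N, S))   # labels 1..C-1; 0 is reserved for the blank
print(train_step(features, target))
```

Counting how often `train_step` returns `None` also tells you whether the NaNs come from a few specific batches or appear everywhere once they start.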

LSTM full code:

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.jit as jit
import warnings
from collections import namedtuple
from typing import List, Tuple
from torch import Tensor
import numbers

def script_lstm(input_size, hidden_size, num_layers, bias=True,
                batch_first=False, dropout=False, bidirectional=True):
    
    assert bias
    assert not batch_first
    stack_type = StackedLSTM2
    layer_type = BidirLSTMLayer
    dirs = 2
    return stack_type(num_layers, layer_type,
                      first_layer_args=[LSTMCell, input_size, hidden_size],
                      other_layer_args=[LSTMCell, hidden_size * dirs, hidden_size])

def reverse(lst):
    # type: (List[Tensor]) -> List[Tensor]
    return lst[::-1]

class LSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size, order=1):
        # __constants__ = ['order']
        super(LSTMCell, self).__init__()

        self.order = order
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

        ###weight-bias for st-1, eq.6,7 for N=0###
        self.weight_ch_prev = Parameter(torch.randn(2 * hidden_size, hidden_size))
        self.bias_ch_prev = Parameter(torch.randn(2 * hidden_size))

        ###weight-bias for d(st-1), eq.6,7 for N=1###
        self.weight_ch_dc_prev = Parameter(torch.randn(2 * self.order * hidden_size,hidden_size))
        self.bias_ch_dc_prev = Parameter(torch.randn(2 * self.order  * hidden_size))

        ###weight-bias for st, eq.8 for N=0###
        self.weight_ch_cur = Parameter(torch.randn(hidden_size, hidden_size))
        self.bias_ch_cur = Parameter(torch.randn(hidden_size))

        ###weight-bias for d(st), eq.8 for N=1###
        self.weight_ch_dc_cur = Parameter(torch.randn(self.order * hidden_size,hidden_size))
        self.bias_ch_dc_cur = Parameter(torch.randn(self.order * hidden_size))

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]
        hx, cx, dc = state
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        gates_2 = (torch.mm(dc, self.weight_ch_dc_prev.t()) + self.bias_ch_dc_prev + torch.mm(cx, self.weight_ch_prev.t()) + self.bias_ch_prev)
        ingate_2, forgetgate_2 = gates_2.chunk(2, 1)
        ingate = ingate + ingate_2
        forgetgate = forgetgate + forgetgate_2

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        outgate = outgate + (torch.mm(cy-cx, self.weight_ch_dc_cur.t()) + self.bias_ch_dc_cur + torch.mm(cy, self.weight_ch_cur.t()) + self.bias_ch_cur )
        outgate = torch.sigmoid(outgate)
        hy = outgate * torch.tanh(cy)
        d_c = cy - cx
        return hy, (hy, cy, d_c)
        
class LSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super(LSTMLayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]
        inputs = input.unbind(0)
        outputs = torch.jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state

class ReverseLSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super(ReverseLSTMLayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]
        inputs = reverse(input.unbind(0))
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(reverse(outputs)), state

class BidirLSTMLayer(jit.ScriptModule):
    __constants__ = ['directions']

    def __init__(self, cell, *cell_args):
        super(BidirLSTMLayer, self).__init__()
        self.directions = nn.ModuleList([
            LSTMLayer(cell, *cell_args),
            ReverseLSTMLayer(cell, *cell_args),
        ])
    @jit.script_method
    def forward(self, input, states):
        # type: (Tensor, List[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor, Tensor]]]
        # List[LSTMState]: [forward LSTMState, backward LSTMState]
        outputs = jit.annotate(List[Tensor], [])
        output_states = jit.annotate(List[Tuple[Tensor, Tensor, Tensor]], [])
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for direction in self.directions:
            state = states[i]
            out, out_state = direction(input, state)
            outputs += [out]
            output_states += [out_state]
            i += 1
        return torch.cat(outputs, -1), output_states

def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):

    layers = [layer(*first_layer_args)] + [layer(*other_layer_args)
                                           for _ in range(num_layers - 1)]

    return nn.ModuleList(layers)

class StackedLSTM2(jit.ScriptModule):
    __constants__ = ['layers'] 

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super(StackedLSTM2, self).__init__()

        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
                                        other_layer_args)
    @jit.script_method
    def forward(self, input, states):
        # type: (Tensor, List[List[Tuple[Tensor, Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor, Tensor]]]]
        # List[List[LSTMState]]: The outer list is for layers,
        #                        inner list is for directions.
        output_states = jit.annotate(List[List[Tuple[Tensor, Tensor, Tensor]]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states

Thanks again.
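One difference from `nn.LSTM` that may be worth checking (a hypothesis, not a confirmed cause): every parameter above is drawn from `torch.randn`, i.e. standard deviation 1. With 4096-dimensional AlexNet features this puts gate pre-activations in the tens, and the new `cx` and `dc` terms route the cell state back into the gates through further unscaled hidden_size x hidden_size matrices. The original fastrnns cells also use `torch.randn`, and you report that they train fine, so this only matters if the extra paths are what tips it over. The sketch re-initialises parameters with the rule `nn.LSTM` uses, U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)); `lstm_style_init_` and `TinyGates` are hypothetical helpers, and the input is random data standing in for CNN features.

```python
import math
import torch
import torch.nn as nn

def lstm_style_init_(module: nn.Module, hidden_size: int) -> None:
    """Re-initialise every parameter of `module` in place with the rule nn.LSTM uses:
    U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
    bound = 1.0 / math.sqrt(hidden_size)
    for p in module.parameters():
        nn.init.uniform_(p, -bound, bound)

class TinyGates(nn.Module):
    """Hypothetical stand-in with randn weights shaped like weight_ih / weight_hh."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size))

    def forward(self, x, h):
        return x @ self.weight_ih.t() + h @ self.weight_hh.t()

torch.manual_seed(0)
input_size, hidden_size, batch = 4096, 256, 8
gates = TinyGates(input_size, hidden_size)
x = torch.rand(batch, input_size)          # non-negative, like post-ReLU CNN features
h = torch.zeros(batch, hidden_size)

with torch.no_grad():
    before = gates(x, h).abs().mean().item()
    lstm_style_init_(gates, hidden_size)
    after = gates(x, h).abs().mean().item()
print(f"mean |gate pre-activation|: randn {before:.1f}, scaled {after:.2f}")
```

In the model from this thread, the equivalent would be calling `lstm_style_init_(model.rnn, hidden_size)` once after construction; this has not been tested against the scripted modules here.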

When I use “with autograd.detect_anomaly()” I get the following error messages:


RuntimeError: Function 'Sigmoidbackward' returned  nan values in its 0th output

RuntimeError: Function 'DivBackward0' returned nan values in its 0th output

RuntimeError: Function 'CudnnConvolutionBackward' returned nan values in its 0th output

RuntimeError: Function torch::jit::(anonymous namespace)::DifferentiableGraphBackward returned nan values in its 2th output

Can anyone tell me whether this is caused by overflow, division by zero, etc.?
Edit: I narrowed it down: the problem is in the following lines, but I still can’t resolve it:

ingate = ingate + ingate_2
forgetgate = forgetgate + forgetgate_2
outgate = outgate + (torch.mm(cy-cx, self.weight_ch_dc_cur.t()) + self.bias_ch_dc_cur + torch.mm(cy, self.weight_ch_cur.t()) + self.bias_ch_cur )
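To go one step further than the anomaly messages, tensor hooks can show which of these terms first receives a non-finite gradient. Tensor hooks are easiest to use in eager mode, e.g. a plain `nn.Module` copy of the cell, or by setting the environment variable `PYTORCH_JIT=0`, which disables TorchScript compilation for debugging. One would then call `ingate_2.register_hook(nan_guard("ingate_2"))` and likewise for `forgetgate_2` and the extra outgate term. The toy graph below only demonstrates the mechanism: `nan_guard` and `bad_grads` are hypothetical helpers, and the overflow is forced on purpose (1e30 squared exceeds the float32 range, and inf minus inf gives NaN).

```python
import torch

bad_grads = []   # names of tensors whose incoming gradient contained NaN or inf

def nan_guard(name):
    """Return a tensor hook that records `name` if its gradient is non-finite."""
    def hook(grad):
        if not torch.isfinite(grad).all():
            bad_grads.append(name)
        return grad                      # leave the gradient unchanged
    return hook

x = torch.full((2, 3), 1e30, requires_grad=True)
extra = x * x                            # overflows float32 -> inf in the forward pass
extra.register_hook(nan_guard("extra"))
gate = torch.sigmoid(extra - extra)      # inf - inf = nan, so the gate output is nan
gate.register_hook(nan_guard("gate"))

gate.sum().backward()
print(bad_grads)
```

The gradient reaching `gate` is still finite (it comes straight from `sum`), while the one reaching `extra` is not. Walking backwards through the graph, the first hooked tensor with a non-finite gradient marks where the problem enters.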

Hi Theocharis,

I am having a similar issue. I am trying to run the existing LSTM code from fastrnns on GitHub, but during back-propagation I get a similar error. Did you get a chance to resolve it? The same input runs fine with the native LSTM.

Hey Tejo,
As far as I can recall, I couldn’t resolve this error :confused:

Hi,
Have you resolved this issue? I am facing the same problem. I have posted my issue [here] (Reproducing a code using Residual LSTM but getting Nan values in gradients).