Hi Jack, and thanks for your reply. First of all, what I meant was that I get NaN after two or three batches, not epochs (reducing the learning rate didn't help). What I modified in the original code is that I added one state (the derivative of the cell state, dc) and, of course, the corresponding gate computations, as you can see above. Consequently, I can't understand why it should return NaN (I can run the original code with my model and training code without a problem).
Training code:
# CTC training loop: index 0 is reserved as the CTC blank label.
criterion = nn.CTCLoss(blank=0, reduction='mean')
# with autograd.detect_anomaly():
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    # Single-element length tensors: this assumes an effective batch size of 1
    # (output is time-major, so size(0) is the sequence length, and target's
    # dim 1 is its label length) -- TODO confirm against the model/dataloader.
    input_len = torch.tensor([output.size(0)], dtype=torch.int)
    target_len = torch.tensor([target.size(1)], dtype=torch.int)
    # nn.CTCLoss expects log-probabilities over the class dimension (dim=2).
    log_probs = nn.functional.log_softmax(output, dim=2)
    loss = criterion(log_probs, target, input_len, target_len)
    # NOTE(review): train_loss is initialized outside this snippet.
    train_loss += loss.item()
    loss.backward()
    optimizer.step()
LSTM full code:
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.jit as jit
import warnings
from collections import namedtuple
from typing import List, Tuple
from torch import Tensor
import numbers
def script_lstm(input_size, hidden_size, num_layers, bias=True,
                batch_first=False, dropout=False, bidirectional=True):
    """Build a stacked bidirectional TorchScript LSTM.

    Only the default configuration is supported: bias enabled,
    time-major input, no dropout, bidirectional.
    """
    assert bias
    assert not batch_first
    # Each layer after the first consumes both directions' outputs,
    # so its input width is hidden_size * 2.
    num_directions = 2
    return StackedLSTM2(
        num_layers,
        BidirLSTMLayer,
        first_layer_args=[LSTMCell, input_size, hidden_size],
        other_layer_args=[LSTMCell, hidden_size * num_directions, hidden_size],
    )
def reverse(lst):
    # type: (List[Tensor]) -> List[Tensor]
    """Return a new list with the elements of ``lst`` in reverse order.

    Written as a slice (rather than ``reversed``) so it compiles under
    TorchScript, which calls this from the script methods below.
    """
    return lst[::-1]
class LSTMCell(jit.ScriptModule):
    """LSTM cell extended with a cell-state-derivative input.

    The recurrent state is a 3-tuple ``(hx, cx, dc)`` instead of the usual
    ``(hx, cx)``: ``dc`` is the previous step's cell-state delta
    (``c_{t-1} - c_{t-2}``).  Both ``cx`` and ``dc`` contribute extra terms
    to the input/forget gates (paper eqs. 6-7) and the output gate (eq. 8).
    """

    def __init__(self, input_size, hidden_size, order=1):
        # __constants__ = ['order']
        super(LSTMCell, self).__init__()
        self.order = order
        self.input_size = input_size
        self.hidden_size = hidden_size
        # FIX: the original drew every weight/bias from plain randn (std 1).
        # At that scale the gate pre-activations grow like sqrt(hidden_size),
        # the extra cx/dc gate terms push them further, and the loss blows up
        # to NaN within a few batches.  Scale initialization by
        # 1/sqrt(hidden_size), matching nn.LSTM's default U(-k, k) range.
        stdv = 1.0 / hidden_size ** 0.5
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size) * stdv)
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size) * stdv)
        self.bias_ih = Parameter(torch.randn(4 * hidden_size) * stdv)
        self.bias_hh = Parameter(torch.randn(4 * hidden_size) * stdv)
        ### weight-bias for c_{t-1}, eq. 6-7 for N=0 ###
        self.weight_ch_prev = Parameter(torch.randn(2 * hidden_size, hidden_size) * stdv)
        self.bias_ch_prev = Parameter(torch.randn(2 * hidden_size) * stdv)
        ### weight-bias for d(c_{t-1}), eq. 6-7 for N=1 ###
        self.weight_ch_dc_prev = Parameter(torch.randn(2 * self.order * hidden_size, hidden_size) * stdv)
        self.bias_ch_dc_prev = Parameter(torch.randn(2 * self.order * hidden_size) * stdv)
        ### weight-bias for c_t, eq. 8 for N=0 ###
        self.weight_ch_cur = Parameter(torch.randn(hidden_size, hidden_size) * stdv)
        self.bias_ch_cur = Parameter(torch.randn(hidden_size) * stdv)
        ### weight-bias for d(c_t), eq. 8 for N=1 ###
        self.weight_ch_dc_cur = Parameter(torch.randn(self.order * hidden_size, hidden_size) * stdv)
        self.bias_ch_dc_cur = Parameter(torch.randn(self.order * hidden_size) * stdv)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]
        hx, cx, dc = state
        # Standard LSTM gate pre-activations from input and hidden state.
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        # Extra input/forget-gate terms from c_{t-1} and its delta (eqs. 6-7).
        gates_2 = (torch.mm(dc, self.weight_ch_dc_prev.t()) + self.bias_ch_dc_prev +
                   torch.mm(cx, self.weight_ch_prev.t()) + self.bias_ch_prev)
        ingate_2, forgetgate_2 = gates_2.chunk(2, 1)
        ingate = torch.sigmoid(ingate + ingate_2)
        forgetgate = torch.sigmoid(forgetgate + forgetgate_2)
        cellgate = torch.tanh(cellgate)
        cy = (forgetgate * cx) + (ingate * cellgate)
        # The output gate additionally sees c_t and the new delta c_t - c_{t-1} (eq. 8).
        outgate = outgate + (torch.mm(cy - cx, self.weight_ch_dc_cur.t()) + self.bias_ch_dc_cur +
                             torch.mm(cy, self.weight_ch_cur.t()) + self.bias_ch_cur)
        outgate = torch.sigmoid(outgate)
        hy = outgate * torch.tanh(cy)
        # Cell-state delta carried forward as the third state component.
        d_c = cy - cx
        return hy, (hy, cy, d_c)
class LSTMLayer(jit.ScriptModule):
    """Runs one LSTM cell over a full time-major sequence."""

    def __init__(self, cell, *cell_args):
        super(LSTMLayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]
        step_outputs = jit.annotate(List[Tensor], [])
        # Unbind along time and feed the cell one step at a time,
        # threading the state through.
        for step_input in input.unbind(0):
            step_output, state = self.cell(step_input, state)
            step_outputs.append(step_output)
        return torch.stack(step_outputs), state
class ReverseLSTMLayer(jit.ScriptModule):
    """Runs one LSTM cell over the sequence in reverse time order."""

    def __init__(self, cell, *cell_args):
        super(ReverseLSTMLayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]
        step_outputs = jit.annotate(List[Tensor], [])
        # Walk the time steps last-to-first, threading the state through.
        for step_input in reverse(input.unbind(0)):
            step_output, state = self.cell(step_input, state)
            step_outputs.append(step_output)
        # Flip the collected outputs back into forward time order.
        return torch.stack(reverse(step_outputs)), state
class BidirLSTMLayer(jit.ScriptModule):
    """Forward + reverse LSTM layers whose outputs are concatenated on the last dim."""
    __constants__ = ['directions']

    def __init__(self, cell, *cell_args):
        super(BidirLSTMLayer, self).__init__()
        self.directions = nn.ModuleList([
            LSTMLayer(cell, *cell_args),
            ReverseLSTMLayer(cell, *cell_args),
        ])

    @jit.script_method
    def forward(self, input, states):
        # type: (Tensor, List[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor, Tensor]]]
        # states: [forward LSTMState, backward LSTMState]
        direction_outputs = jit.annotate(List[Tensor], [])
        final_states = jit.annotate(List[Tuple[Tensor, Tensor, Tensor]], [])
        # Manual index instead of enumerate: TorchScript limitation,
        # see https://github.com/pytorch/pytorch/issues/14471
        idx = 0
        for direction in self.directions:
            direction_output, direction_state = direction(input, states[idx])
            direction_outputs.append(direction_output)
            final_states.append(direction_state)
            idx += 1
        return torch.cat(direction_outputs, -1), final_states
def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
    """Build a ModuleList of ``num_layers`` layers.

    The first layer is constructed from ``first_layer_args`` (it may have a
    different input width); every remaining layer uses ``other_layer_args``.
    """
    layers = [layer(*first_layer_args)]
    for _ in range(num_layers - 1):
        layers.append(layer(*other_layer_args))
    return nn.ModuleList(layers)
class StackedLSTM2(jit.ScriptModule):
    """Stack of (bidirectional) LSTM layers with per-direction state lists."""
    __constants__ = ['layers']

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super(StackedLSTM2, self).__init__()
        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
                                        other_layer_args)

    @jit.script_method
    def forward(self, input, states):
        # type: (Tensor, List[List[Tuple[Tensor, Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor, Tensor]]]]
        # states: outer list over layers, inner list over directions.
        final_states = jit.annotate(List[List[Tuple[Tensor, Tensor, Tensor]]], [])
        output = input
        # Manual index instead of enumerate: TorchScript limitation,
        # see https://github.com/pytorch/pytorch/issues/14471
        idx = 0
        for rnn_layer in self.layers:
            output, layer_state = rnn_layer(output, states[idx])
            final_states.append(layer_state)
            idx += 1
        return output, final_states
Thanks again.