I am new to PyTorch and would appreciate some direction on how to create and use an LSTM cell with multiple additional gates. For example, I would like to implement the LSTM cell described in this paper.
You just take an LSTMCell implementation and modify it.
For example, you can copy the implementation here and modify it for your own purposes:
very much appreciated!
just a note that fastrnn module seems to only work with pytorch v 1.0.0
I am having a little difficulty getting the implementation of a custom LSTM cell a la https://github.com/pytorch/benchmark/tree/master/rnns/fastrnns to work. Please see my code with a minimal implementation below.
custom RNN
import torch
from collections import namedtuple
'''
Define a creator as a function:
(options) -> (inputs, params, forward, backward_setup, backward)
inputs: the inputs to the returned 'forward'. One can call
forward(*inputs) directly.
params: List[Tensor] all requires_grad=True parameters.
forward: function / graph executor / module
One can call rnn(rnn_inputs) using the outputs of the creator.
backward_setup: backward_inputs = backward_setup(*outputs)
Then, we pass backward_inputs to backward. If None, then it is assumed to
be the identity function.
backward: Given `output = backward_setup(*forward(*inputs))`, performs
backpropagation. If None, then nothing happens.
fastrnns.bench times the forward and backward invocations.
'''
def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
    """Advance a standard LSTM cell by one time step.

    Args:
        input: 2-D input for this step (it is fed to ``torch.mm``).
        hidden: pair ``(hx, cx)`` with the previous hidden and cell state.
        w_ih: input-to-hidden weight; its rows hold the four gates stacked
            in the order input, forget, cell, output.
        w_hh: hidden-to-hidden weight, stacked the same way.
        b_ih: bias added to the input projection.
        b_hh: bias added to the hidden projection.

    Returns:
        ``(hy, cy)``: the new hidden state and the new cell state.
    """
    h_prev, c_prev = hidden

    # One fused pre-activation for all four gates, split along dim 1 below.
    preact = torch.mm(input, w_ih.t()) + torch.mm(h_prev, w_hh.t()) + b_ih + b_hh
    i_pre, f_pre, g_pre, o_pre = preact.chunk(4, 1)

    i_gate = torch.sigmoid(i_pre)
    f_gate = torch.sigmoid(f_pre)
    g_cand = torch.tanh(g_pre)
    o_gate = torch.sigmoid(o_pre)

    # Keep part of the old cell state and add the gated candidate.
    c_next = (f_gate * c_prev) + (i_gate * g_cand)
    h_next = o_gate * torch.tanh(c_next)
    return h_next, c_next
def varlen_lstm_creator(script=False, **kwargs):
    """Build the ModelDef for the variable-length LSTM benchmark.

    Args:
        script: forwarded to ``varlen_lstm_factory``; when true the cell and
            the loop around it are compiled with ``torch.jit.script``.
        **kwargs: forwarded unchanged to ``varlen_lstm_inputs``.

    Returns:
        A ``ModelDef`` whose ``forward`` can be called as
        ``forward(*inputs)``. Only the first layer's weights (``params[0]``)
        are appended to ``inputs``; ``params`` holds every layer's weights
        flattened into one list.
    """
    sequences, _lengths, hidden, params, _module = varlen_lstm_inputs(
        return_module=False, **kwargs)
    first_layer_weights = params[0]
    forward_args = [sequences, hidden] + first_layer_weights
    return ModelDef(
        inputs=forward_args,
        params=flatten_list(params),
        forward=varlen_lstm_factory(lstm_cell, script),
        backward_setup=varlen_lstm_backward_setup,
        backward=simple_backward)
# Record returned by a creator function; see the module docstring above.
ModelDef = namedtuple(
    'ModelDef',
    [
        'inputs',          # arguments for `forward`, i.e. forward(*inputs)
        'params',          # flat list of parameter tensors
        'forward',         # function / graph executor / module
        'backward_setup',  # maps forward outputs to the inputs of `backward`
        'backward',        # performs backpropagation
    ],
)
def varlen_lstm_inputs(minlen=30, maxlen=100,
                       numLayers=1, inputSize=512, hiddenSize=512,
                       miniBatch=64, return_module=False, device='cpu',
                       seed=None, **kwargs):
    """Generate random variable-length sequences plus LSTM state and weights.

    Args:
        minlen: smallest sequence length that can be drawn (inclusive).
        maxlen: upper bound on the sequence length (exclusive, because it is
            passed as ``high`` to ``torch.randint``).
        numLayers: number of LSTM layers.
        inputSize: feature size of every sequence element.
        hiddenSize: hidden size of the LSTM.
        miniBatch: number of sequences to generate.
        return_module: when true, the ``torch.nn.LSTM`` module is returned as
            the last element; otherwise that slot is ``None``.
        device: device on which all tensors and the module are created.
        seed: if not ``None``, passed to ``torch.manual_seed`` first.
        **kwargs: accepted and ignored.

    Returns:
        ``(x, lengths, (hx, cx), all_weights, module_or_None)`` where ``x`` is
        a list of ``miniBatch`` tensors of shape ``(length, inputSize)`` and
        ``hx``/``cx`` have shape ``(numLayers, miniBatch, hiddenSize)``.
    """
    if seed is not None:
        torch.manual_seed(seed)

    lengths = torch.randint(
        low=minlen, high=maxlen, size=[miniBatch],
        dtype=torch.long, device=device)
    x = [torch.randn(n, inputSize, device=device) for n in lengths]

    state_shape = (numLayers, miniBatch, hiddenSize)
    hx = torch.randn(*state_shape, device=device)
    cx = torch.randn(*state_shape, device=device)

    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers).to(device)

    # NB: lstm.all_weights format:
    # wih, whh, bih, bhh = lstm.all_weights[layer]
    module = lstm if return_module else None
    return x, lengths, (hx, cx), lstm.all_weights, module
def varlen_lstm_backward_setup(forward_output, seed=None):
    """Prepare the arguments for the backward pass of the varlen LSTM.

    Args:
        forward_output: result of the forward function; only element 0 (the
            list of per-sequence output tensors) is used.
        seed: if not ``None``, passed to ``torch.manual_seed`` before the
            random gradient is drawn.

    Returns:
        ``(padded, grad)``: the sequences padded into a single tensor by
        ``pad_sequence`` and a random tensor of the same shape to use as the
        incoming gradient.
    """
    # Compare against None rather than truthiness so that seed=0 is honoured,
    # consistent with varlen_lstm_inputs.
    if seed is not None:
        torch.manual_seed(seed)
    sequences = forward_output[0]
    padded = torch.nn.utils.rnn.pad_sequence(sequences)
    grad = torch.randn_like(padded)
    return padded, grad
def varlen_lstm_factory(cell, script):
    """Return a function that runs `cell` over a list of variable-length sequences.

    Args:
        cell: step function called as
            ``cell(x, (h, c), wih, whh, bih, bhh) -> (h, c)``.
        script: when true, both `cell` and the returned loop are compiled
            with ``torch.jit.script``.

    Returns:
        ``dynamic_rnn(sequences, hiddens, wih, whh, bih, bhh)``, which
        processes each sequence independently, one time step at a time, and
        returns ``(outputs, (hx_outs, cx_outs))``: the stacked per-step hidden
        states for each sequence and the final hidden / cell state of each.
    """
    def dynamic_rnn(sequences, hiddens, wih, whh, bih, bhh):
        # type: (List[Tensor], Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]
        h0, c0 = hiddens
        # Split the initial state along dim 1 so entry i belongs to sequence i.
        h0_per_seq = h0.unbind(1)
        c0_per_seq = c0.unbind(1)

        outputs = []
        hx_outs = []
        cx_outs = []
        for batch in range(len(sequences)):
            h_t, c_t = h0_per_seq[batch], c0_per_seq[batch]
            step_outputs = []
            for x_t in sequences[batch].unbind(0):
                # unsqueeze(0) gives the single time step a leading dim of 1.
                h_t, c_t = cell(
                    x_t.unsqueeze(0), (h_t, c_t), wih, whh, bih, bhh)
                step_outputs.append(h_t)
            outputs.append(torch.stack(step_outputs))
            hx_outs.append(h_t.unsqueeze(0))
            cx_outs.append(c_t.unsqueeze(0))
        return outputs, (hx_outs, cx_outs)

    if script:
        # Rebinding `cell` here is visible to dynamic_rnn through its closure.
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)
    return dynamic_rnn
def simple_backward(output, grad_output):
    """Backpropagate `grad_output` through `output`.

    Returns whatever ``output.backward(grad_output)`` returns.
    """
    result = output.backward(grad_output)
    return result
# list[list[T]] -> list[T]
def flatten_list(lst):
    """Concatenate the inner iterables of `lst` into one new flat list.

    Only one level of nesting is removed.
    """
    return [item for inner in lst for item in inner]
a basic POF
import torch
import torch.nn as nn
from CustomRNN import varlen_lstm_creator
# Minimal reproduction: treats the object returned by varlen_lstm_creator as
# if it were a callable LSTM module.
torch.manual_seed(1)
# NOTE(review): varlen_lstm_creator returns a ModelDef namedtuple (inputs,
# params, forward, backward_setup, backward), not a module, so `lstm` here is
# not callable.
lstm = varlen_lstm_creator(inputSize=3, hiddenSize=3) # Input dim is 3, output dim is 3
inputs = [torch.randn(1, 3) for _ in range(5)] # make a sequence of length 5
# initialize the hidden state.
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))
for i in inputs:
    # Step through the sequence one element at a time.
    # after each step, hidden contains the hidden state.
    # This is the first place `lstm` is called, i.e. where the
    # "'ModelDef' object is not callable" TypeError is raised.
    out, hidden = lstm(i.view(1, 1, -1), hidden)
# Alternative: feed the whole sequence at once as a (seq_len, 1, features) tensor.
inputs = torch.cat(inputs).view(len(inputs), 1, -1)
hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3)) # clean out hidden state
out, hidden = lstm(inputs, hidden)
print(out)
print(hidden)
Error
TypeError: ‘ModelDef’ object is not callable
Any help is greatly appreciated!
The whole file / helper functions in there are not callable as-is, it’s part of a larger benchmark harness.
I meant that you probably want to just grab the cell functions themselves.
OK. But how do I integrate a cell function with for example torch.nn.LSTM? I appreciate your time.
You don't integrate a cell function with torch.nn.LSTM; you write a for loop around it.
Something like:
# Illustrative sketch only: `F.lstm_cell`, `timesteps`, `batches`, `input`,
# `weight`, `bias` and `weight_out` are placeholders, not names defined in
# this thread. The point is the shape of the loop: call the cell once per
# time step (and per batch element), carrying (h, c) forward.
# NOTE(review): the nesting of the last line was lost in the paste; it is
# assumed to sit inside the inner loop — confirm.
for t in range(timesteps):
    for b in range(batches):
        h, c = F.lstm_cell(input, weight, bias, h, c)
        # F.linear takes (input, weight): project the hidden state with weight_out.
        out = F.linear(h, weight_out)
This would be an example of such a loop that you would write: https://github.com/pytorch/benchmark/blob/09eaadc1d05ad442b1f0beb82babf875bbafb24b/rnns/fastrnns/factory.py#L165-L182
I was able to create and use a number of custom RNN/LSTM cells using https://github.com/NVIDIA/apex/tree/master/apex/RNN.
I am struggling with a custom LSTM cell that returns tuples for both the hidden and cell state. Please find minimal code below. Any help is appreciated.
class skipLSTMRNNCell(RNNCell):
    """RNNCell subclass that drives SkipLSTMCell.

    On top of whatever the base class sets up, it owns two extra parameters,
    `w_uh` (shape ``(1, hidden_size)``) and `b_uh` (shape ``(1,)``), which are
    passed to the cell function on every call.

    NOTE(review): RNNCell is defined outside this file (presumably apex.RNN).
    Everything said here about inherited attributes (`hidden`, `init_hidden`,
    `cell`, `w_ih`, `w_hh`, `w_ho`, `b_ih`, `b_hh`, `reset_parameters`) is
    inferred from how they are used below — confirm against the base class.
    """

    def __init__(self, input_size, hidden_size, bias=False, output_size=None):
        # SkipLSTMCell splits its gate pre-activation into four chunks.
        gate_multiplier = 4
        super(skipLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, SkipLSTMCell, n_hidden_states=4,
                                              bias=bias, output_size=output_size)
        # Extra weight and bias handed to SkipLSTMCell as w_uh / b_uh.
        self.w_uh = Parameter(xavier_uniform(torch.Tensor(1, hidden_size)))
        self.b_uh = Parameter(torch.ones(1))
        # NOTE(review): reset_parameters() is inherited; if it re-initialises
        # every registered parameter, the xavier / ones initialisation above
        # is overwritten — confirm.
        self.reset_parameters()

    def forward(self, input):
        """Run the cell for one step on `input` and return the stored hidden states as a tuple."""
        # if not inited or bsz has changed this will create hidden states
        self.init_hidden(input.size()[0])
        hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden
        # NOTE(review): SkipLSTMCell returns a pair (new_output, new_state),
        # each itself a tuple, so list(...) yields [new_output, new_state]
        # rather than a flat list of n_hidden_states tensors. self.hidden[0]
        # below would then be a tuple, not a tensor. This looks like the
        # source of the tuple-handling problem — confirm the contract the
        # base class expects from the cell function.
        self.hidden = list(
            self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_uh, self.b_uh,
                      b_ih=self.b_ih, b_hh=self.b_hh)
        )
        # Project the first hidden entry when the output size differs from
        # the hidden size.
        if self.output_size != self.hidden_size:
            self.hidden[0] = F.linear(self.hidden[0], self.w_ho)
        return tuple(self.hidden)

    def new_like(self, new_input_size=None):
        """Create a new cell of the same type and settings, optionally with a different input size."""
        if new_input_size is None:
            new_input_size = self.input_size
        return type(self)(
            new_input_size,
            self.hidden_size,
            self.bias,
            self.output_size)
cell
def SkipLSTMCell(input, hidden, w_ih, w_hh, w_uh, b_uh=None, b_ih=None, b_hh=None):
    """One step of an LSTM cell with a binary update gate that can skip the state update.

    A regular LSTM candidate state is computed first. A binary update gate
    then chooses, element-wise, between the candidate state and the previous
    state.

    Args:
        input: input for this step.
        hidden: 4-tuple ``(c_prev, h_prev, update_prob_prev,
            cum_update_prob_prev)`` — note the cell state comes first.
        w_ih: input-to-hidden weight for the four stacked gates.
        w_hh: hidden-to-hidden weight for the four stacked gates.
        w_uh: weight projecting the candidate cell state to the update
            probability.
        b_uh: optional bias for the update-probability projection.
        b_ih: optional bias for the input projection.
        b_hh: optional bias for the hidden projection.

    Returns:
        ``(new_output, new_state)`` where ``new_output = (new_h, update_gate)``
        and ``new_state = (new_c, new_h, new_update_prob,
        new_cum_update_prob)``.
    """
    c_prev, h_prev, update_prob_prev, cum_update_prob_prev = hidden

    # Standard LSTM gates; chunk(4, 1) splits the fused pre-activation along
    # dim 1 into input, forget, cell and output parts.
    gates = F.linear(input, w_ih, b_ih) + F.linear(h_prev, w_hh, b_hh)
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    # torch.sigmoid replaces the deprecated F.sigmoid and matches the call
    # used for the update probability below.
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    # Candidate state, used only where the update gate fires.
    new_c_tilde = (forgetgate * c_prev) + (ingate * cellgate)
    new_h_tilde = outgate * torch.tanh(new_c_tilde)

    # Compute value for the update prob
    new_update_prob_tilde = torch.sigmoid(F.linear(new_c_tilde, w_uh, b_uh))

    # Compute value for the update gate
    cum_update_prob = cum_update_prob_prev + torch.min(update_prob_prev, 1. - cum_update_prob_prev)

    # round
    # NOTE(review): BinaryLayer is defined outside this file; presumably it
    # binarises its input — confirm.
    bn = BinaryLayer()
    update_gate = bn(cum_update_prob)

    # Apply update gate: take the candidate where the gate is 1 and keep the
    # previous value where it is 0.
    new_c = update_gate * new_c_tilde + (1. - update_gate) * c_prev
    new_h = update_gate * new_h_tilde + (1. - update_gate) * h_prev
    # The cumulative probability resets to 0 on an update and carries over
    # otherwise.
    new_cum_update_prob = update_gate * 0. + (1. - update_gate) * cum_update_prob
    new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev

    new_state = (new_c, new_h, new_update_prob, new_cum_update_prob)
    new_output = (new_h, update_gate)
    return new_output, new_state
Can anyone explain what this line is doing?
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)