So I assume CudnnRNN and LSTMFused still do not support second-order gradients? PyTorch seemed to work on some toy cases I wrote, but it failed with a “CudnnRNN is not differentiable twice” error whenever I tried something real (like implementing a gradient-norm regularizer for a GAN with an RNN discriminator).
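Roughly, this is the failing pattern — a minimal sketch of a WGAN-GP-style gradient-norm penalty through an LSTM discriminator (sizes and names here are made up for illustration):

import torch
from torch import nn
from torch.autograd import Variable, grad

# Toy LSTM "discriminator"; on the GPU the CudnnRNN path is taken by default.
disc = nn.LSTM(input_size=16, hidden_size=32).cuda()
x = Variable(torch.randn(20, 4, 16).cuda(), requires_grad=True)  # (seq, batch, feature)

out, _ = disc(x)
score = out[-1].sum()

# The gradient-norm penalty needs grad-of-grad: create_graph=True keeps the
# graph of the first backward so it can be differentiated again.
g, = grad(score, x, create_graph=True)
penalty = (g.norm(2) - 1) ** 2
penalty.backward()  # RuntimeError: CudnnRNN is not differentiable twice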
I have no idea how second-order gradients on CudnnRNN are supposed to work, so I made a patch that adds another keyword argument, fused, which toggles whether CudnnRNN and the fused LSTMFused/GRUFused cells are used when applicable:
diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py
index 0fbbaf2..1bcce7e 100644
--- a/torch/nn/_functions/rnn.py
+++ b/torch/nn/_functions/rnn.py
@@ -26,7 +26,10 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
         hgates = F.linear(hidden[0], w_hh)
         state = fusedBackend.LSTMFused()
         return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh)
+    return LSTMUnfusedCell(input, hidden, w_ih, w_hh, b_ih, b_hh)
+
+def LSTMUnfusedCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     hx, cx = hidden
     gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
@@ -50,7 +53,10 @@ def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
         gh = F.linear(hidden, w_hh)
         state = fusedBackend.GRUFused()
         return state(gi, gh, hidden) if b_ih is None else state(gi, gh, hidden, b_ih, b_hh)
+    return GRUUnfusedCell(input, hidden, w_ih, w_hh, b_ih, b_hh)
+
+def GRUUnfusedCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     gi = F.linear(input, w_ih, b_ih)
     gh = F.linear(hidden, w_hh, b_hh)
     i_r, i_i, i_n = gi.chunk(3, 1)
@@ -208,16 +214,16 @@ def VariableRecurrentReverse(batch_sizes, inner):
 def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
                 dropout=0, train=True, bidirectional=False, batch_sizes=None,
-                dropout_state=None, flat_weight=None):
+                dropout_state=None, flat_weight=None, fused=True):
     if mode == 'RNN_RELU':
         cell = RNNReLUCell
     elif mode == 'RNN_TANH':
         cell = RNNTanhCell
     elif mode == 'LSTM':
-        cell = LSTMCell
+        cell = LSTMCell if fused else LSTMUnfusedCell
     elif mode == 'GRU':
-        cell = GRUCell
+        cell = GRUCell if fused else GRUUnfusedCell
     else:
         raise Exception('Unknown mode: {}'.format(mode))
@@ -255,7 +261,8 @@ class CudnnRNN(NestedIOFunction):
     def __init__(self, mode, input_size, hidden_size, num_layers=1,
                  batch_first=False, dropout=0, train=True, bidirectional=False,
-                 batch_sizes=None, dropout_state=None, flat_weight=None):
+                 batch_sizes=None, dropout_state=None, flat_weight=None,
+                 fused=True):
         super(CudnnRNN, self).__init__()
         if dropout_state is None:
             dropout_state = {}
@@ -344,7 +351,7 @@ class CudnnRNN(NestedIOFunction):
 def RNN(*args, **kwargs):
     def forward(input, *fargs, **fkwargs):
-        if cudnn.is_acceptable(input.data):
+        if cudnn.is_acceptable(input.data) and kwargs.get('fused', True):
             func = CudnnRNN(*args, **kwargs)
         else:
             func = AutogradRNN(*args, **kwargs)
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 5f36278..152fa2c 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -11,7 +11,7 @@ class RNNBase(Module):
     def __init__(self, mode, input_size, hidden_size,
                  num_layers=1, bias=True, batch_first=False,
-                 dropout=0, bidirectional=False):
+                 dropout=0, bidirectional=False, fused=True):
         super(RNNBase, self).__init__()
         self.mode = mode
         self.input_size = input_size
@@ -22,6 +22,7 @@ class RNNBase(Module):
         self.dropout = dropout
         self.dropout_state = {}
         self.bidirectional = bidirectional
+        self.fused = fused
         num_directions = 2 if bidirectional else 1
         if mode == 'LSTM':
@@ -155,7 +156,8 @@ class RNNBase(Module):
             bidirectional=self.bidirectional,
             batch_sizes=batch_sizes,
             dropout_state=self.dropout_state,
-            flat_weight=flat_weight
+            flat_weight=flat_weight,
+            fused=self.fused,
         )
         output, hidden = func(input, self.all_weights, hx)
         if is_packed:
@@ -318,6 +320,8 @@ class LSTM(RNNBase):
         dropout: If non-zero, introduces a dropout layer on the outputs of each
             RNN layer except the last layer
         bidirectional: If True, becomes a bidirectional RNN. Default: False
+        fused: If True, try to use the GPU-accelerated fused kernels when possible
+            (faster, but second-order derivatives are not supported). Default: True
     Inputs: input, (h_0, c_0)
         - **input** (seq_len, batch, input_size): tensor containing the features
@@ -395,6 +399,8 @@ class GRU(RNNBase):
         dropout: If non-zero, introduces a dropout layer on the outputs of each
             RNN layer except the last layer
         bidirectional: If True, becomes a bidirectional RNN. Default: False
+        fused: If True, try to use the GPU-accelerated fused kernels when possible
+            (faster, but second-order derivatives are not supported). Default: True
     Inputs: input, h_0
         - **input** (seq_len, batch, input_size): tensor containing the features
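With the patch applied, passing fused=False should select the unfused autograd cells instead of CudnnRNN/LSTMFused, so the gradient-norm penalty sketch from the top can be differentiated twice (slower on GPU, since it skips the fused kernels), e.g.:

# Hypothetical usage, only valid with the patch above applied.
disc = nn.LSTM(input_size=16, hidden_size=32, fused=False).cuda()

out, _ = disc(x)
g, = grad(out[-1].sum(), x, create_graph=True)
((g.norm(2) - 1) ** 2).backward()  # no “not differentiable twice” error on this path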