Does PyTorch support double backward through RNNs?


(赖文泽) #1

I am building an improved-Wasserstein-style GAN (WGAN-GP) in which both the generator and the discriminator are RNNs.
Everything else works, but when calculating the gradient penalty I get the following error:

  File "/usr/local/lib/python3.5/dist-packages/torch/autograd/variable.py", line 156, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
  File "/usr/local/lib/python3.5/dist-packages/torch/autograd/__init__.py", line 98, in backward
    variables, grad_variables, retain_graph)
RuntimeError: CudnnRNN is not differentiable twice

The gradient penalty is computed just like the usual `calc_gradient_penalty`:

def calc_gradient_penalty(netD, real_data, fake_data, lambda_gp=None):
    """Compute the WGAN-GP gradient penalty of ``netD``.

    For each sample, a random point is drawn on the straight line between the
    real and the fake sample. The critic is evaluated there, and
    ``lambda_gp * (||grad||_2 - 1) ** 2`` is averaged over the batch. The
    gradient norm is taken over *all* non-batch dimensions of each sample.

    Args:
        netD: critic / discriminator (module or callable) mapping a batch of
            samples to one score per sample.
        real_data: batch of real samples; dim 0 is the batch dimension.
        fake_data: batch of generated samples, same shape as ``real_data``.
        lambda_gp: penalty weight. Defaults to the module-level ``LAMBDA``.

    Returns:
        Scalar penalty. The graph is built with ``create_graph=True`` so the
        penalty can be back-propagated into ``netD``'s parameters.
    """
    if lambda_gp is None:
        lambda_gp = LAMBDA
    batch_size = real_data.size(0)

    # One interpolation coefficient per sample (not per element), broadcast
    # over the remaining dims. It is sized from the data rather than a global
    # BATCH_SIZE, so a smaller final batch or a different shape also works.
    alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)))
    alpha = alpha.type_as(real_data).expand_as(real_data)  # follow data's device/dtype
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    interpolates = Variable(interpolates, requires_grad=True)

    # cuDNN's RNN kernels raise "CudnnRNN is not differentiable twice", so
    # build this part of the graph with cuDNN disabled, then restore the
    # caller's setting. NOTE(review): on older PyTorch releases the fused
    # (non-cuDNN) CUDA LSTM/GRU kernels may also lack double backward.
    cudnn_was_enabled = torch.backends.cudnn.enabled
    torch.backends.cudnn.enabled = False
    try:
        disc_interpolates = netD(interpolates)
        grad_outputs = torch.ones(disc_interpolates.size()).type_as(
            disc_interpolates.data)
        gradients = torch.autograd.grad(outputs=disc_interpolates,
                                        inputs=interpolates,
                                        grad_outputs=grad_outputs,
                                        create_graph=True,
                                        retain_graph=True)[0]
    finally:
        torch.backends.cudnn.enabled = cudnn_was_enabled

    # Flatten each sample so the norm covers every non-batch dimension.
    # A plain norm(dim=1) on a (batch, seq, feat) tensor would only norm the
    # sequence axis.
    gradients = gradients.contiguous().view(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp

G and D:

class Generator(nn.Module):
    """Two-layer, batch-first LSTM generator.

    Maps a noise sequence with 128 features per step to a sequence with 33
    features per step. The top LSTM layer's per-step hidden states are
    returned directly, with no projection layer.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Positional args: input_size=128, hidden_size=33. hidden_size is also
        # the output feature width because forward() returns raw LSTM outputs.
        self.lstm1 = nn.LSTM(128, 33, num_layers=2, batch_first=True)

    def forward(self, x):
        # hx=None: PyTorch initialises the hidden and cell states to zeros.
        outputs, _final_state = self.lstm1(x, None)
        return outputs

class Discriminator(nn.Module):
    """Two-layer, batch-first LSTM discriminator / critic.

    Reads a sequence of shape (batch, seq_len, input_size) and scores it from
    the top layer's output at the last time step, returning (batch, 1).

    Args:
        input_size: features per time step (default 33, as before).
        hidden_size: LSTM hidden width (default 64, as before).
        num_layers: number of stacked LSTM layers (default 2, as before).
        use_sigmoid: squash the score into (0, 1). Defaults to True, which
            keeps the original behaviour. A WGAN / WGAN-GP critic should pass
            ``use_sigmoid=False``: its score must be unbounded, and a
            saturating sigmoid distorts the Wasserstein estimate.
    """

    def __init__(self, input_size=33, hidden_size=64, num_layers=2, *,
                 use_sigmoid=True):
        super(Discriminator, self).__init__()
        self.lstm1 = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        head = [nn.Linear(hidden_size, 1)]
        if use_sigmoid:
            head.append(nn.Sigmoid())
        # Always a Sequential, so parameter names ("out.0.weight", ...) and
        # thus saved checkpoints stay compatible either way.
        self.out = nn.Sequential(*head)

    def forward(self, x):
        r_out, _ = self.lstm1(x, None)
        # Score only the last time step's top-layer output.
        return self.out(r_out[:, -1, :])

The generator output (and the discriminator input) has shape [batch_size, 300, 33]; the noise has shape [batch_size, 300, 128].

Does anyone know why this happens?


(Tom Sercu) #2

cuDNN provides very fast primitives, but PyTorch has no access to their internals.
If you implement the RNN in plain PyTorch instead (see the tutorials and example scripts), PyTorch has access to the internal buffers that are needed to use the gradients in the objective.


(Quan Gan) #3

So I assume CudnnRNN and LSTMFused still do not support second-order gradients? PyTorch seemed to work on some toy cases I wrote, but it failed with a “CudnnRNN is not differentiable twice” error whenever I tried something real (like implementing a gradient-norm regularizer for a GAN with an RNN discriminator).

I have no idea how second-order gradients through CudnnRNN are supposed to work, so I made a patch that adds a keyword argument called `fused`. It toggles between using CudnnRNN/LSTMFused where applicable (`fused=True`, the default) and falling back to the plain, unfused autograd cells (`fused=False`):

diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py
index 0fbbaf2..1bcce7e 100644
--- a/torch/nn/_functions/rnn.py
+++ b/torch/nn/_functions/rnn.py
@@ -26,7 +26,10 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
         hgates = F.linear(hidden[0], w_hh)
         state = fusedBackend.LSTMFused()
         return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh)
+    return LSTMUnfusedCell(input, hidden, w_ih, w_hh, b_ih, b_hh)
 
+
+def LSTMUnfusedCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     hx, cx = hidden
     gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
 
@@ -50,7 +53,10 @@ def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
         gh = F.linear(hidden, w_hh)
         state = fusedBackend.GRUFused()
         return state(gi, gh, hidden) if b_ih is None else state(gi, gh, hidden, b_ih, b_hh)
+    return GRUUnfusedCell(input, hidden, w_ih, w_hh, b_ih, b_hh)
+
 
+def GRUUnfusedCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
     gi = F.linear(input, w_ih, b_ih)
     gh = F.linear(hidden, w_hh, b_hh)
     i_r, i_i, i_n = gi.chunk(3, 1)
@@ -208,16 +214,16 @@ def VariableRecurrentReverse(batch_sizes, inner):
 
 def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
                 dropout=0, train=True, bidirectional=False, batch_sizes=None,
-                dropout_state=None, flat_weight=None):
+                dropout_state=None, flat_weight=None, fused=True):
 
     if mode == 'RNN_RELU':
         cell = RNNReLUCell
     elif mode == 'RNN_TANH':
         cell = RNNTanhCell
     elif mode == 'LSTM':
-        cell = LSTMCell
+        cell = LSTMCell if fused else LSTMUnfusedCell
     elif mode == 'GRU':
-        cell = GRUCell
+        cell = GRUCell if fused else GRUUnfusedCell
     else:
         raise Exception('Unknown mode: {}'.format(mode))
 
@@ -255,7 +261,8 @@ class CudnnRNN(NestedIOFunction):
 
     def __init__(self, mode, input_size, hidden_size, num_layers=1,
                  batch_first=False, dropout=0, train=True, bidirectional=False,
-                 batch_sizes=None, dropout_state=None, flat_weight=None):
+                 batch_sizes=None, dropout_state=None, flat_weight=None,
+                 fused=True):
         super(CudnnRNN, self).__init__()
         if dropout_state is None:
             dropout_state = {}
@@ -344,7 +351,7 @@ class CudnnRNN(NestedIOFunction):
 
 def RNN(*args, **kwargs):
     def forward(input, *fargs, **fkwargs):
-        if cudnn.is_acceptable(input.data):
+        if cudnn.is_acceptable(input.data) and kwargs.get('fused', True):
             func = CudnnRNN(*args, **kwargs)
         else:
             func = AutogradRNN(*args, **kwargs)
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 5f36278..152fa2c 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -11,7 +11,7 @@ class RNNBase(Module):
 
     def __init__(self, mode, input_size, hidden_size,
                  num_layers=1, bias=True, batch_first=False,
-                 dropout=0, bidirectional=False):
+                 dropout=0, bidirectional=False, fused=True):
         super(RNNBase, self).__init__()
         self.mode = mode
         self.input_size = input_size
@@ -22,6 +22,7 @@ class RNNBase(Module):
         self.dropout = dropout
         self.dropout_state = {}
         self.bidirectional = bidirectional
+        self.fused = fused
         num_directions = 2 if bidirectional else 1
 
         if mode == 'LSTM':
@@ -155,7 +156,8 @@ class RNNBase(Module):
             bidirectional=self.bidirectional,
             batch_sizes=batch_sizes,
             dropout_state=self.dropout_state,
-            flat_weight=flat_weight
+            flat_weight=flat_weight,
+            fused=self.fused,
         )
         output, hidden = func(input, self.all_weights, hx)
         if is_packed:
@@ -318,6 +320,8 @@ class LSTM(RNNBase):
         dropout: If non-zero, introduces a dropout layer on the outputs of each
             RNN layer except the last layer
         bidirectional: If True, becomes a bidirectional RNN. Default: False
+        fused: If True, try to use GPU-accelerated version if possible (faster
+            but cannot compute second order derivative).  Default: True
 
     Inputs: input, (h_0, c_0)
         - **input** (seq_len, batch, input_size): tensor containing the features
@@ -395,6 +399,8 @@ class GRU(RNNBase):
         dropout: If non-zero, introduces a dropout layer on the outputs of each
             RNN layer except the last layer
         bidirectional: If True, becomes a bidirectional RNN. Default: False
+        fused: If True, try to use GPU-accelerated version if possible (faster
+            but cannot compute second order derivative).  Default: True
 
     Inputs: input, h_0
         - **input** (seq_len, batch, input_size): tensor containing the features

#4

Right now, double backward through CudnnRNN and the fused cells is not supported. We are still working on completing it.


(Koustuv Sinha) #5

Hi @smth, is the implementation still in the works?