Port weights from keras.layers.CuDNNGRU into torch.nn.GRU

Hello,

Apologies in advance, as I know this is not strictly PyTorch-specific; the bug may reside in Keras. See this thread, where I asked for help with the same message as here.

Here is my use case: I am trying to load weights from a trained keras.layers.CuDNNGRU into a torch.nn.GRU (I know, risky business from the get-go). The problem: even after loading the weights, seemingly correctly, the two implementations return wildly different results.

You will find in this gist some code to reproduce the error, together with a Dockerfile to reproduce my environment exactly:

  • Ubuntu 16.04
  • CUDA 8.0
  • CuDNN 5.x
  • torch 0.4.1
  • keras 2.2.4
  • tensorflow-gpu 1.2.0

The gist contains two scripts, which should show you that I cannot make it work in either direction:

  1. create a torch GRU, load its weights into a CuDNNGRU, and check for equality
  2. create a CuDNNGRU, load its weights into a torch GRU, and check for equality

Note that the structure of the keras.layers.CuDNNGRU weights, as returned by get_weights(), is as follows:

  1. input matrix of shape (input_dimension, 3 * hidden_dimension)
  2. recurrent matrix of shape (hidden_dimension, 3 * hidden_dimension)
  3. bias of shape (2 * 3 * hidden_dimension), where the first half is the input bias and the second half is the recurrent bias (I am unsure of that part; I cannot work it out from Keras’ source code)

The structure of the torch.nn.GRU weights is as follows:

  1. input matrix torch.nn.GRU.weight_ih_l0 has shape (3 * hidden_dimension, input_dimension)
  2. recurrent matrix torch.nn.GRU.weight_hh_l0 has shape (3 * hidden_dimension, hidden_dimension)
  3. input bias torch.nn.GRU.bias_ih_l0 has shape (3 * hidden_dimension,)
  4. recurrent bias torch.nn.GRU.bias_hh_l0 has shape (3 * hidden_dimension,)
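As a shape-only sanity check, the two layouts above can be related with a small NumPy sketch (the dimensions and variable names below are invented for illustration):

```python
import numpy as np

input_dim, hidden = 4, 6  # arbitrary example sizes

# keras.layers.CuDNNGRU get_weights() layout
keras_kernel = np.zeros((input_dim, 3 * hidden))     # input matrix
keras_recurrent = np.zeros((hidden, 3 * hidden))     # recurrent matrix
keras_bias = np.zeros(2 * 3 * hidden)                # both biases concatenated

# torch.nn.GRU layout (layer 0)
torch_weight_ih = np.zeros((3 * hidden, input_dim))  # weight_ih_l0
torch_weight_hh = np.zeros((3 * hidden, hidden))     # weight_hh_l0
torch_bias_ih = np.zeros(3 * hidden)                 # bias_ih_l0
torch_bias_hh = np.zeros(3 * hidden)                 # bias_hh_l0

# The matrices are transposed relative to each other, and the single Keras
# bias vector holds both torch bias vectors back to back.
assert keras_kernel.T.shape == torch_weight_ih.shape
assert keras_recurrent.T.shape == torch_weight_hh.shape
assert keras_bias.size == torch_bias_ih.size + torch_bias_hh.size
```

So even before worrying about gate order, every matrix has to change shape on the way across.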

EDIT 1: try importing torch before tensorflow in my scripts and see what happens: complete mayhem, with some sort of buffer/stack overflow. Here is an excerpt of the trace:

7fc8aa55a000-7fc8aa59a000 rw-p 00000000 00:00 0 
7fc8aa59a000-7fc8aa5c1000 r--p 00000000 00:9a 60                         /usr/lib/locale/C.UTF-8/LC_CTYPE
7fc8aa5c1000-7fc8aa5c2000 r--p 00000000 00:9a 59                         /usr/lib/locale/C.UTF-8/LC_NUMERIC
7fc8aa5c2000-7fc8aa5c3000 r--p 00000000 00:9a 58                         /usr/lib/locale/C.UTF-8/LC_TIME
7fc8aa5c3000-7fc8aa735000 r--p 00000000 00:9a 57                         /usr/lib/locale/C.UTF-8/LC_COLLATE
7fc8aa735000-7fc8aa736000 r--p 00000000 00:9a 56                         /usr/lib/locale/C.UTF-8/LC_MONETARY
7fc8aa736000-7fc8aa737000 r--p 00000000 00:9a 55                         /usr/lib/locale/C.UTF-8/LC_MESSAGES/SYS_LC_MESSAGES
7fc8aa737000-7fc8aa73e000 r--s 00000000 00:9a 48                         /usr/lib/x86_64-linux-gnu/gconv/gconv-modules.cache
7fc8aa73e000-7fc8aa744000 rw-p 00000000 00:00 0 
7fc8aa744000-7fc8aa745000 r--p 00000000 00:9a 53                         /usr/lib/locale/C.UTF-8/LC_PAPER
7fc8aa745000-7fc8aa746000 r--p 00000000 00:9a 52                         /usr/lib/locale/C.UTF-8/LC_NAME
7fc8aa746000-7fc8aa747000 r--p 00000000 00:9a 51                         /usr/lib/locale/C.UTF-8/LC_ADDRESS
7fc8aa747000-7fc8aa748000 r--p 00000000 00:9a 50                         /usr/lib/locale/C.UTF-8/LC_TELEPHONE
7fc8aa748000-7fc8aa749000 r--p 00000000 00:9a 49                         /usr/lib/locale/C.UTF-8/LC_MEASUREMENT
7fc8aa749000-7fc8aa74a000 r--p 00000000 00:9a 45                         /usr/lib/locale/C.UTF-8/LC_IDENTIFICATION
7fc8aa74a000-7fc8aa74b000 r--p 00025000 00:9a 32                         /lib/x86_64-linux-gnu/ld-2.23.so
7fc8aa74b000-7fc8aa74c000 rw-p 00026000 00:9a 32                         /lib/x86_64-linux-gnu/ld-2.23.so
7fc8aa74c000-7fc8aa74d000 rw-p 00000000 00:00 0 
7ffde0031000-7ffde0058000 rw-p 00000000 00:00 0                          [stack]
7ffde01e1000-7ffde01e4000 r--p 00000000 00:00 0                          [vvar]
7ffde01e4000-7ffde01e6000 r-xp 00000000 00:00 0                          [vdso]
ffffffffff600000-ffffffffff601000 r-xp 00000000 00:00 0                  [vsyscall]
Aborted (core dumped)

Now I am wondering whether a memory leak is causing the difference between the two classes in the first place!

Hello, note that this was solved by Yuyang Huang in the following Keras issue. The solution is:

Keras -> torch:

import numpy as np

def convert_input_kernel(kernel):
    # Split the Keras kernel of shape (input_dim, 3 * hidden) into its z, r, h gates.
    kernel_z, kernel_r, kernel_h = np.hsplit(kernel, 3)
    # torch expects the gates in r, z, n order.
    kernels = [kernel_r, kernel_z, kernel_h]
    # Reinterpret each gate's buffer with the transposed shape (reshape, not transpose).
    return np.vstack([k.reshape(k.T.shape) for k in kernels])

def convert_recurrent_kernel(kernel):
    # Split the Keras recurrent kernel of shape (hidden, 3 * hidden) into its gates.
    kernel_z, kernel_r, kernel_h = np.hsplit(kernel, 3)
    kernels = [kernel_r, kernel_z, kernel_h]
    return np.vstack(kernels)

def convert_bias(bias):
    # The Keras bias is (2 * 3 * hidden,): input bias first, recurrent bias second.
    bias = bias.reshape(2, 3, -1)
    # Swap the z and r gates within each half.
    return bias[:, [1, 0, 2], :].reshape(-1)

torch -> Keras (just the reverse transformation):

def convert_input_kernel(kernel):
    # Split the torch weight_ih of shape (3 * hidden, input_dim) into its r, z, n gates.
    kernel_r, kernel_z, kernel_h = np.vsplit(kernel, 3)
    # Keras expects the gates in z, r, h order.
    kernels = [kernel_z, kernel_r, kernel_h]
    return np.hstack([k.reshape(k.T.shape) for k in kernels])

def convert_recurrent_kernel(kernel):
    kernel_r, kernel_z, kernel_h = np.vsplit(kernel, 3)
    kernels = [kernel_z, kernel_r, kernel_h]
    return np.hstack(kernels)

def convert_bias(bias):
    # Swapping z and r is its own inverse, so this is identical to the function above.
    bias = bias.reshape(2, 3, -1)
    return bias[:, [1, 0, 2], :].reshape(-1)
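As a sanity check, the two directions should be exact inverses of each other. Below is a self-contained round-trip sketch; the functions are the ones above, renamed with k2t_/t2k_ prefixes (my naming, to avoid the clash between the two definitions), and the sizes are arbitrary:

```python
import numpy as np

# Keras -> torch (from above)
def k2t_input_kernel(kernel):
    kernel_z, kernel_r, kernel_h = np.hsplit(kernel, 3)
    return np.vstack([k.reshape(k.T.shape) for k in (kernel_r, kernel_z, kernel_h)])

def k2t_recurrent_kernel(kernel):
    kernel_z, kernel_r, kernel_h = np.hsplit(kernel, 3)
    return np.vstack([kernel_r, kernel_z, kernel_h])

def k2t_bias(bias):
    return bias.reshape(2, 3, -1)[:, [1, 0, 2], :].reshape(-1)

# torch -> Keras (from above)
def t2k_input_kernel(kernel):
    kernel_r, kernel_z, kernel_h = np.vsplit(kernel, 3)
    return np.hstack([k.reshape(k.T.shape) for k in (kernel_z, kernel_r, kernel_h)])

def t2k_recurrent_kernel(kernel):
    kernel_r, kernel_z, kernel_h = np.vsplit(kernel, 3)
    return np.hstack([kernel_z, kernel_r, kernel_h])

def t2k_bias(bias):
    return bias.reshape(2, 3, -1)[:, [1, 0, 2], :].reshape(-1)

rng = np.random.default_rng(0)
input_dim, hidden = 4, 6
kernel = rng.standard_normal((input_dim, 3 * hidden))
recurrent = rng.standard_normal((hidden, 3 * hidden))
bias = rng.standard_normal(2 * 3 * hidden)

# Converting to the torch layout and back should recover the originals exactly.
assert np.allclose(t2k_input_kernel(k2t_input_kernel(kernel)), kernel)
assert np.allclose(t2k_recurrent_kernel(k2t_recurrent_kernel(recurrent)), recurrent)
assert np.allclose(t2k_bias(k2t_bias(bias)), bias)
```

This only verifies that the two transformations are mutually inverse, not that either matches cuDNN's internal layout; for that, compare the actual layer outputs as in the gist.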