Hello,
Sorry in advance, as I know this is not strictly PyTorch specific; the bug may reside in Keras. See this thread, where I asked for help with the same message as here.
Here is my use case: I am trying to load weights from a trained keras.layers.CuDNNGRU into a torch.nn.GRU (I know, risky business from the get-go). Problem: even after loading the weights, seemingly correctly, the two implementations return wildly different results.
You will find in this gist some code to reproduce the error, together with a Dockerfile
to reproduce my environment exactly:
- Ubuntu 16.04
- CUDA 8.0
- CuDNN 5.x
- torch 0.4.1
- keras 2.2.4
- tensorflow-gpu 1.2.0
You will find in the gist two scripts, which should show you that I cannot make it work in either direction:
- create a torch GRU, load its weights into a CuDNNGRU, and check for equality
- create a CuDNNGRU, load its weights into a torch GRU, and check for equality
Note that the structure of the keras.layers.CuDNNGRU weights, as returned by get_weights, is as follows:
- input matrix of shape (input_dimension, 3 * hidden_dimension)
- recurrent matrix of shape (hidden_dimension, 3 * hidden_dimension)
- bias of shape (2 * 3 * hidden_dimension,), where I believe the first half is the input bias and the second half the recurrent bias (that part I am unsure of; I cannot work it out from Keras’ source code)
The structure of the torch.nn.GRU weights is as follows:
- input matrix torch.nn.GRU.weight_ih_l0 has shape (3 * hidden_dimension, input_dimension)
- recurrent matrix torch.nn.GRU.weight_hh_l0 has shape (3 * hidden_dimension, hidden_dimension)
- input bias torch.nn.GRU.bias_ih_l0 has shape (3 * hidden_dimension,)
- recurrent bias torch.nn.GRU.bias_hh_l0 has shape (3 * hidden_dimension,)
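Beyond transposing the matrices, the gate ordering almost certainly matters: PyTorch concatenates the GRU gate blocks as (reset, update, new), while Keras, as far as I can tell, uses (update, reset, new). Here is a minimal sketch of the conversion I would try, numpy only; the gate order and the bias split (first half input, second half recurrent) are assumptions to verify against your Keras/cuDNN versions:

```python
import numpy as np

def keras_cudnngru_to_torch(kernel, recurrent_kernel, bias, hidden_dim):
    """Turn keras.layers.CuDNNGRU get_weights() output into the four
    arrays torch.nn.GRU expects (weight_ih_l0, weight_hh_l0,
    bias_ih_l0, bias_hh_l0).

    Assumed (not verified): Keras gate order is (z, r, h), PyTorch is
    (r, z, n); the first half of the bias vector is the input bias.
    """
    def reorder(w, axis):
        # split into the three gate blocks and swap the z and r blocks
        z, r, h = np.split(w, 3, axis=axis)
        return np.concatenate([r, z, h], axis=axis)

    # Keras stores (input_dim, 3 * hidden); PyTorch wants (3 * hidden, input_dim)
    weight_ih = reorder(kernel, axis=1).T
    weight_hh = reorder(recurrent_kernel, axis=1).T
    bias_ih = reorder(bias[:3 * hidden_dim], axis=0)
    bias_hh = reorder(bias[3 * hidden_dim:], axis=0)
    return weight_ih, weight_hh, bias_ih, bias_hh
```

The results could then be copied into the torch module with e.g. `gru.weight_ih_l0.data.copy_(torch.from_numpy(weight_ih))`, and similarly for the other three parameters.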
EDIT 1: try importing torch before tensorflow in my scripts and see what happens: complete mayhem, with some sort of buffer / stack overflow. Here is an excerpt of the trace:
7fc8aa55a000-7fc8aa59a000 rw-p 00000000 00:00 0
7fc8aa59a000-7fc8aa5c1000 r--p 00000000 00:9a 60 /usr/lib/locale/C.UTF-8/LC_CTYPE
7fc8aa5c1000-7fc8aa5c2000 r--p 00000000 00:9a 59 /usr/lib/locale/C.UTF-8/LC_NUMERIC
7fc8aa5c2000-7fc8aa5c3000 r--p 00000000 00:9a 58 /usr/lib/locale/C.UTF-8/LC_TIME
7fc8aa5c3000-7fc8aa735000 r--p 00000000 00:9a 57 /usr/lib/locale/C.UTF-8/LC_COLLATE
7fc8aa735000-7fc8aa736000 r--p 00000000 00:9a 56 /usr/lib/locale/C.UTF-8/LC_MONETARY
7fc8aa736000-7fc8aa737000 r--p 00000000 00:9a 55 /usr/lib/locale/C.UTF-8/LC_MESSAGES/SYS_LC_MESSAGES
7fc8aa737000-7fc8aa73e000 r--s 00000000 00:9a 48 /usr/lib/x86_64-linux-gnu/gconv/gconv-modules.cache
7fc8aa73e000-7fc8aa744000 rw-p 00000000 00:00 0
7fc8aa744000-7fc8aa745000 r--p 00000000 00:9a 53 /usr/lib/locale/C.UTF-8/LC_PAPER
7fc8aa745000-7fc8aa746000 r--p 00000000 00:9a 52 /usr/lib/locale/C.UTF-8/LC_NAME
7fc8aa746000-7fc8aa747000 r--p 00000000 00:9a 51 /usr/lib/locale/C.UTF-8/LC_ADDRESS
7fc8aa747000-7fc8aa748000 r--p 00000000 00:9a 50 /usr/lib/locale/C.UTF-8/LC_TELEPHONE
7fc8aa748000-7fc8aa749000 r--p 00000000 00:9a 49 /usr/lib/locale/C.UTF-8/LC_MEASUREMENT
7fc8aa749000-7fc8aa74a000 r--p 00000000 00:9a 45 /usr/lib/locale/C.UTF-8/LC_IDENTIFICATION
7fc8aa74a000-7fc8aa74b000 r--p 00025000 00:9a 32 /lib/x86_64-linux-gnu/ld-2.23.so
7fc8aa74b000-7fc8aa74c000 rw-p 00026000 00:9a 32 /lib/x86_64-linux-gnu/ld-2.23.so
7fc8aa74c000-7fc8aa74d000 rw-p 00000000 00:00 0
7ffde0031000-7ffde0058000 rw-p 00000000 00:00 0 [stack]
7ffde01e1000-7ffde01e4000 r--p 00000000 00:00 0 [vvar]
7ffde01e4000-7ffde01e6000 r-xp 00000000 00:00 0 [vdso]
ffffffffff600000-ffffffffff601000 r-xp 00000000 00:00 0 [vsyscall]
Aborted (core dumped)
Now I am wondering whether a memory issue is what triggers the difference between the two classes in the first place!