Loading TensorFlow GRUCell weights into PyTorch

The checkpoint for the Slot Attention object discovery model (google-research/slot_attention in the google-research/google-research GitHub repository; checkpoint: https://console.cloud.google.com/storage/browser/gresearch/slot-attention) contains a tensorflow.keras.layers.GRUCell (TF v2).

The corresponding keys in the checkpoint (with their shapes) are:

{
  'network/layer_with_weights-0/slot_attention/gru/recurrent_kernel/.ATTRIBUTES/VARIABLE_VALUE' : torch.Size([64, 192]), 
  'network/layer_with_weights-0/slot_attention/gru/kernel/.ATTRIBUTES/VARIABLE_VALUE' : torch.Size([64, 192]),
  'network/layer_with_weights-0/slot_attention/gru/bias/.ATTRIBUTES/VARIABLE_VALUE' : torch.Size([2, 192])
}
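For context: as far as I know, Keras' GRUCell packs the three gates side by side along the last axis of kernel and recurrent_kernel, in the order update (z), reset (r), candidate (h), so kernel has shape (input_dim, 3 * units). A bias of shape (2, 3 * units) means the cell was built with reset_after=True (the TF2 default): row 0 is the input bias and row 1 the recurrent bias. The sketch below splits arrays with the checkpoint's shapes into per-gate blocks; the random arrays are stand-ins for the real checkpoint values, and units=64 is inferred from 192 / 3.

```python
import numpy as np

UNITS = 64  # 192 // 3, inferred from the checkpoint shapes

# Random stand-ins with the same shapes as the checkpoint tensors.
rng = np.random.default_rng(0)
kernel = rng.standard_normal((64, 3 * UNITS)).astype(np.float32)
recurrent_kernel = rng.standard_normal((UNITS, 3 * UNITS)).astype(np.float32)
bias = rng.standard_normal((2, 3 * UNITS)).astype(np.float32)

# Keras gate order along the last axis: update (z), reset (r), candidate (h).
k_z, k_r, k_h = np.split(kernel, 3, axis=1)               # each (input_dim, units)
rk_z, rk_r, rk_h = np.split(recurrent_kernel, 3, axis=1)  # each (units, units)

# reset_after=True: row 0 is the input bias, row 1 the recurrent bias.
input_bias, recurrent_bias = bias
b_z, b_r, b_h = np.split(input_bias, 3)                   # each (units,)

print(k_z.shape, rk_r.shape, b_h.shape)
```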

The parameters of the corresponding PyTorch GRUCell seem to be:

{
  'slot_attention.gru.weight_ih': torch.Size([192, 64]),
  'slot_attention.gru.weight_hh': torch.Size([192, 64]),
  'slot_attention.gru.bias_ih': torch.Size([192]), 
  'slot_attention.gru.bias_hh': torch.Size([192])
}

How do I convert them to a PyTorch GRUCell, especially given that the GRU cell is defined differently in PyTorch and TensorFlow?

Would the following be correct (in pseudo-code)?

weight_hh = recurrent_kernel.t()
weight_ih = kernel.t()
bias_ih, bias_hh = bias.unbind()

What should the bias format be? (This may be related.)
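As far as I can tell, the bias needs the same gate reordering as the kernels, not just unbind(): Keras stores each bias row as [z | r | h], while PyTorch's bias_ih and bias_hh are laid out as [r | z | n]. This maps cleanly only when the Keras cell uses reset_after=True (a (2, 3 * units) bias), because PyTorch always applies the reset gate after the recurrent matrix product; a reset_after=False cell (bias of shape (3 * units,)) cannot in general be reproduced exactly by PyTorch's GRUCell. The helper name below is hypothetical.

```python
import numpy as np

def keras_bias_to_torch(bias):
    """Hypothetical helper: split a Keras reset_after=True GRU bias of shape
    (2, 3 * units) into PyTorch (bias_ih, bias_hh), each of shape (3 * units,)."""
    units = bias.shape[1] // 3
    # (2, 3*units) -> (2, 3 gates, units); Keras gate order is z, r, h.
    gates = bias.reshape(2, 3, units)
    # Reorder to PyTorch's r, z, n and flatten back.
    reordered = gates[:, [1, 0, 2], :].reshape(2, 3 * units)
    return reordered[0], reordered[1]

# Label each gate block so the reordering is visible (units=2):
# input row z=1, r=2, h=3; recurrent row z=10, r=20, h=30.
bias = np.array([[1, 1, 2, 2, 3, 3],
                 [10, 10, 20, 20, 30, 30]], dtype=np.float32)
bias_ih, bias_hh = keras_bias_to_torch(bias)
print(bias_ih)  # [2. 2. 1. 1. 3. 3.]
print(bias_hh)  # [20. 20. 10. 10. 30. 30.]
```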

Thank you!

Hey @vadimkantorov, were you able to figure this out? I am facing exactly the same issue. Thanks.

I don’t remember for sure now, but I think my snippet above is correct. It would be best to test it against TF outputs, but I don’t have that setup…

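I’ve tested it against TensorFlow/Keras outputs and it does not work (the output is different). Besides transposing, the gates have to be reordered: Keras stores them as update (z), reset (r), candidate (h), while PyTorch expects reset (r), update (z), new (n). I found the following conversion to work for a GRU layer (TensorFlow r2.4rc0 tf.keras.layers.GRU to PyTorch 1.9.0 torch.nn.GRU):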

  1. Save the weights from the TensorFlow model (for example, to an .npz file):
import random
import numpy as np
import tensorflow as tf

SEED=1995

np.random.seed(SEED)
random.seed(SEED)
tf.random.set_seed(SEED)

gru = tf.keras.layers.GRU(
    units=5,
    return_sequences=True,
    kernel_initializer=tf.keras.initializers.GlorotUniform(seed=SEED),
    recurrent_initializer=tf.keras.initializers.Orthogonal(seed=SEED),
    bias_initializer=tf.keras.initializers.Zeros()
)

y_tf = gru(tf.ones((1, 3, 5)), training=False)  # forward pass with ones

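np.savez(
    'tf_model_weights.npz',
    gru_kernel=gru.weights[0].numpy(),            # (input_dim, 3 * units), gates z, r, h
    gru_recurrent_kernel=gru.weights[1].numpy(),  # (units, 3 * units), gates z, r, h
    gru_bias=gru.weights[2].numpy()               # (2, 3 * units) since reset_after=True
)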
  2. Load the weights into the PyTorch model:
import random as r
import numpy as np
import torch

SEED=1995
torch.set_printoptions(precision=8)

r.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

npz_weights = np.load('tf_model_weights.npz')


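def convert_input_kernel(kernel):
    # Keras: (input_size, 3 * hidden) with column blocks in gate order z, r, h.
    # PyTorch: (3 * hidden, input_size) with row blocks in gate order r, z, n.
    kernel_z, kernel_r, kernel_h = np.hsplit(kernel, 3)
    return np.concatenate((kernel_r.T, kernel_z.T, kernel_h.T))


def convert_recurrent_kernel(kernel):
    # Same gate reordering and transpose as for the input kernel.
    kernel_z, kernel_r, kernel_h = np.hsplit(kernel, 3)
    return np.concatenate((kernel_r.T, kernel_z.T, kernel_h.T))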


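def convert_bias(bias):
    # Row 0 is the input bias, row 1 the recurrent bias (reset_after=True).
    # Swap the z and r blocks in each row: (z, r, h) -> (r, z, n).
    bias = bias.reshape(2, 3, -1)
    return bias[:, [1, 0, 2], :].reshape((2, -1))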


gru = torch.nn.GRU(
    hidden_size=5,
    input_size=5,
    num_layers=1,
    bidirectional=False,
    batch_first=True
)
with torch.no_grad():
    for pn, p in gru.named_parameters():
        if 'weight_ih' in pn:
            p.copy_(torch.from_numpy(convert_input_kernel(npz_weights['gru_kernel'])))
        elif 'weight_hh' in pn:
            p.copy_(torch.from_numpy(convert_recurrent_kernel(npz_weights['gru_recurrent_kernel'])))
        elif 'bias_ih' in pn:
            p.copy_(torch.from_numpy(convert_bias(npz_weights['gru_bias'])[0]))
        else:  # bias_hh
            p.copy_(torch.from_numpy(convert_bias(npz_weights['gru_bias'])[1]))
  3. Compare the outputs:

TensorFlow:

>>> y_tf
    <tf.Tensor: shape=(1, 3, 5), dtype=float32, numpy=
    array([[[-0.36656722, -0.4693069 , -0.16722648,  0.36081928, 0.1643753 ],
            [-0.4628504 , -0.6815055 , -0.18605384,  0.58125013, 0.2494137 ],
            [-0.48067108, -0.7698146 , -0.16238967,  0.70518744, 0.3005259 ]]], dtype=float32)>

PyTorch:

>>> gru.eval()
    GRU(5, 5, batch_first=True)
>>> y_pt, _ = gru(torch.ones(1, 3, 5))
>>> y_pt
    tensor([[[-0.36656719, -0.46930692, -0.16722649,  0.36081928,  0.16437532],
             [-0.46285039, -0.68150550, -0.18605389,  0.58125013,  0.24941370],
             [-0.48067111, -0.76981461, -0.16238970,  0.70518738,  0.30052590]]],
           grad_fn=<TransposeBackward1>)
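Note that step 1 initializes the bias with zeros, so this comparison would also pass with a wrong bias ordering. The numpy-only sketch below (no TensorFlow or PyTorch needed) re-implements one step of each cell from its documented equations: the Keras GRUCell with reset_after=True and torch.nn.GRUCell. It uses random, nonzero biases so that the bias conversion is actually exercised. The cell functions are my own re-implementations rather than either library's code, the activations are assumed to be the defaults (sigmoid for the gates, tanh for the candidate), and convert_kernel applies the same reordering as convert_input_kernel/convert_recurrent_kernel.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def keras_gru_step(x, h, kernel, recurrent_kernel, bias):
    """One Keras-style GRUCell step, reset_after=True; gate order z, r, h."""
    xz, xr, xh = np.split(x @ kernel + bias[0], 3, axis=-1)
    hz, hr, hh = np.split(h @ recurrent_kernel + bias[1], 3, axis=-1)
    z = sigmoid(xz + hz)
    r = sigmoid(xr + hr)
    cand = np.tanh(xh + r * hh)  # reset gate applied after the matmul
    return z * h + (1.0 - z) * cand

def torch_gru_step(x, h, weight_ih, weight_hh, bias_ih, bias_hh):
    """One torch.nn.GRUCell-style step; gate order r, z, n."""
    ir, iz, in_ = np.split(x @ weight_ih.T + bias_ih, 3, axis=-1)
    hr, hz, hn = np.split(h @ weight_hh.T + bias_hh, 3, axis=-1)
    r = sigmoid(ir + hr)
    z = sigmoid(iz + hz)
    n = np.tanh(in_ + r * hn)
    return (1.0 - z) * n + z * h

def convert_kernel(kernel):
    # Keras (in, 3u), blocks z, r, h -> PyTorch (3u, in), blocks r, z, n.
    k_z, k_r, k_h = np.hsplit(kernel, 3)
    return np.concatenate((k_r.T, k_z.T, k_h.T))

def convert_bias(bias):
    # Reorder each row z, r, h -> r, z, n; row 0 -> bias_ih, row 1 -> bias_hh.
    return bias.reshape(2, 3, -1)[:, [1, 0, 2], :].reshape(2, -1)

rng = np.random.default_rng(0)
in_size, units = 4, 3
kernel = rng.standard_normal((in_size, 3 * units))
recurrent_kernel = rng.standard_normal((units, 3 * units))
bias = rng.standard_normal((2, 3 * units))  # nonzero, unlike step 1
x = rng.standard_normal((2, in_size))       # batch of 2
h = rng.standard_normal((2, units))

b_ih, b_hh = convert_bias(bias)
y_keras = keras_gru_step(x, h, kernel, recurrent_kernel, bias)
y_torch = torch_gru_step(x, h, convert_kernel(kernel),
                         convert_kernel(recurrent_kernel), b_ih, b_hh)
print(np.allclose(y_keras, y_torch))  # True

# Skipping the bias reordering breaks the match once the bias is nonzero.
y_wrong = torch_gru_step(x, h, convert_kernel(kernel),
                         convert_kernel(recurrent_kernel), bias[0], bias[1])
print(np.allclose(y_keras, y_wrong))  # False
```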

Hope someone will find this helpful.
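The same conversion should carry over to the GRUCell from the original question, since its (2, 192) bias indicates reset_after=True. Below is a sketch that maps the three checkpoint entries to the four PyTorch keys listed in the question. It assumes the checkpoint values have already been read into numpy arrays under the key names shown there (random arrays stand in for them here); reorder_gates and slot_attention_gru_to_torch are hypothetical helpers. Converting the result to tensors (e.g. with torch.from_numpy) and passing it to the model's load_state_dict is left out so the sketch runs with numpy alone.

```python
import numpy as np

PREFIX = "network/layer_with_weights-0/slot_attention/gru/"
SUFFIX = "/.ATTRIBUTES/VARIABLE_VALUE"

def reorder_gates(a, axis):
    """Keras gate order (z, r, h) -> PyTorch gate order (r, z, n) along `axis`."""
    z, r, h = np.split(a, 3, axis=axis)
    return np.concatenate((r, z, h), axis=axis)

def slot_attention_gru_to_torch(ckpt, torch_prefix="slot_attention.gru."):
    kernel = ckpt[PREFIX + "kernel" + SUFFIX]                      # (in, 3u)
    recurrent_kernel = ckpt[PREFIX + "recurrent_kernel" + SUFFIX]  # (u, 3u)
    bias = reorder_gates(ckpt[PREFIX + "bias" + SUFFIX], axis=1)   # (2, 3u)
    # .copy() makes the transposed / sliced arrays contiguous.
    return {
        torch_prefix + "weight_ih": reorder_gates(kernel, axis=1).T.copy(),
        torch_prefix + "weight_hh": reorder_gates(recurrent_kernel, axis=1).T.copy(),
        torch_prefix + "bias_ih": bias[0].copy(),
        torch_prefix + "bias_hh": bias[1].copy(),
    }

# Random stand-ins with the checkpoint's shapes.
rng = np.random.default_rng(0)
ckpt = {
    PREFIX + "kernel" + SUFFIX: rng.standard_normal((64, 192)).astype(np.float32),
    PREFIX + "recurrent_kernel" + SUFFIX: rng.standard_normal((64, 192)).astype(np.float32),
    PREFIX + "bias" + SUFFIX: rng.standard_normal((2, 192)).astype(np.float32),
}
state = slot_attention_gru_to_torch(ckpt)
for name, value in state.items():
    print(name, value.shape)
```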


Hi, I am trying to do the opposite, transferring weights from PyTorch to TensorFlow. I tried to invert your code as follows:

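def convert_kernel(kernel):
    # PyTorch (3 * hidden, in) with row blocks r, z, n -> Keras (in, 3 * hidden) with column blocks z, r, h.
    kernel_r, kernel_z, kernel_h = np.vsplit(kernel, 3)
    return np.concatenate((kernel_z.T, kernel_r.T, kernel_h.T), axis=1)

def convert_bias(bias):
    # Swapping the first two gate blocks is its own inverse, so this is unchanged.
    bias = bias.reshape(2, 3, -1)
    return bias[:, [1, 0, 2], :].reshape((2, -1))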

# GRU Layer:

kernel_input = convert_kernel(torch_params[6])
kernel_h = convert_kernel(torch_params[7])
bias = convert_bias(np.stack((torch_params[8], torch_params[9]), axis=0))

model_keras.layers[8].set_weights([kernel_input, 
                                   kernel_h, 
                                   bias])

…but this doesn’t work. Do you have any idea why?

Thanks!
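For the reverse direction, a round trip is a cheap way to check the inverse functions without TensorFlow or PyTorch: converting Keras-layout arrays to PyTorch layout and back should reproduce them exactly. The sketch below restates the forward conversion from the answer above next to the inverse from this post, with gate names matching each library's order; the function names are hypothetical, and the torch_params indices from the post are model-specific and not used here.

```python
import numpy as np

def keras_kernel_to_torch(kernel):
    # (in, 3u) with column blocks z, r, h -> (3u, in) with row blocks r, z, n
    k_z, k_r, k_h = np.hsplit(kernel, 3)
    return np.concatenate((k_r.T, k_z.T, k_h.T))

def torch_kernel_to_keras(weight):
    # (3u, in) with row blocks r, z, n -> (in, 3u) with column blocks z, r, h
    w_r, w_z, w_n = np.vsplit(weight, 3)
    return np.concatenate((w_z.T, w_r.T, w_n.T), axis=1)

def swap_bias_gates(bias):
    # Swaps the first two gate blocks in each row; applying it twice is the
    # identity, so the same function maps Keras -> PyTorch and back.
    return bias.reshape(2, 3, -1)[:, [1, 0, 2], :].reshape(2, -1)

rng = np.random.default_rng(0)
kernel = rng.standard_normal((4, 9))
recurrent_kernel = rng.standard_normal((3, 9))
bias = rng.standard_normal((2, 9))

assert np.array_equal(torch_kernel_to_keras(keras_kernel_to_torch(kernel)), kernel)
assert np.array_equal(
    torch_kernel_to_keras(keras_kernel_to_torch(recurrent_kernel)), recurrent_kernel)
assert np.array_equal(swap_bias_gates(swap_bias_gates(bias)), bias)
print("round trip OK")
```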


Never mind, it works. The problem was in the code just before that: apparently TensorFlow’s Reshape does not behave the same way as PyTorch’s view().
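A likely explanation, though this is an assumption since the post does not show that code: tf.reshape and PyTorch's view() both use row-major order, so they agree on identical arrays, and the difference usually comes from the tensor layout. Keras image models are typically channels-last (NHWC) while PyTorch uses channels-first (NCHW), so flattening a feature map yields a different element order unless it is transposed first. In the sketch below, numpy's reshape (also row-major) stands in for both libraries.

```python
import numpy as np

n, c, h, w = 1, 2, 2, 3
x_nchw = np.arange(n * c * h * w).reshape(n, c, h, w)  # PyTorch layout
x_nhwc = x_nchw.transpose(0, 2, 3, 1)                   # same values, Keras layout

flat_torch = x_nchw.reshape(n, -1)  # like x.view(n, -1) on the PyTorch tensor
flat_keras = x_nhwc.reshape(n, -1)  # like tf.reshape(x, (n, -1)) on the TF tensor

print(np.array_equal(flat_torch, flat_keras))  # False: different element order

# Moving channels last before flattening reproduces the Keras order
# (in PyTorch: x.permute(0, 2, 3, 1).reshape(n, -1)).
flat_matched = x_nchw.transpose(0, 2, 3, 1).reshape(n, -1)
print(np.array_equal(flat_matched, flat_keras))  # True
```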