Converting tensorflow model to pytorch!

Greetings,

My data consists of time-series samples with 925 steps, each containing 2 features. In other words, my data is shaped as (samples, steps, features)

The model I’m currently implementing works in TensorFlow, but I’m having trouble properly implementing it in PyTorch class.

def sampling(samp_args):
   """
   generate z
   """
    z_mean, z_log_sigma = samp_args

    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_sigma) * epsilon

# encoder
encoder = Bidirectional(GRU(32, name="encoder1", return_sequences=True))(
    main_input
)
encoder = Dropout(0.25, name="drop_encoder1")(encoder)
encoder = Bidirectional(GRU(32, name="encoder2", return_sequences=False))(encoder)
encoder = Dropout(0.25, name="drop_encoder2")(encoder)
codings_mean = Dense(units=16, name="encoding_mean", activation="linear")(
    encoder
)
codings_log_var = Dense(
    units=16, name="encoding_log_var", activation="linear"
)(encoder)
codings = Lambda(sampling, output_shape=(output_size,))([codings_mean, codings_log_var])


# decoder
decoder = RepeatVector(lcs_scaled.shape[1], name="repeat")(codings)
decoder = merge.concatenate([aux_input, decoder])
decoder = GRU(gru_size, name="decoder1", return_sequences=True)(decoder)
decoder = Dropout(dropout_val, name="drop_decoder1")(decoder)
decoder = GRU(gru_size, name="decoder2", return_sequences=True)(decoder)
decoder = TimeDistributed(Dense(1, activation="linear"), name="time_dist")(decoder)

# VAE
model = Model(model_input, decoder)

class VariationalRecurrentAutoEncoder(nn.Module):
    def __init__(self, input_size_main=2,input_size_aux=1, gru_size=32,linear_size=16,dim_seq = 925, drop_out=0.25, use_gpu=False):
        super(VariationalRecurrentAutoEncoder, self).__init__()

        # Configuration
        self.dropout = nn.Dropout(drop_out)
        #self.dim_seq = dim_seq
        self.use_gpu = use_gpu

        # Encoder
        self.gru_encoder = nn.GRU(input_size=input_size_main,
                                hidden_size=gru_size,
                                num_layers=1,
                                batch_first=True,
                                bidirectional=True)

        self.gru_encoder2 = nn.GRU(input_size=gru_size*2,
                                hidden_size=gru_size,
                                num_layers=1,
                                batch_first=True,
                                bidirectional=True)
        # Latent space
        self.linear_mean = nn.Linear(gru_size*2, linear_size)
        self.linear_logvar= nn.Linear(gru_size*2, linear_size)

        # Decoder
        self.gru_decoder= nn.GRU(input_size=linear_size+input_size_aux,
                                hidden_size=gru_size,
                                num_layers=1,
                                batch_first=True,
                                bidirectional=False)

        self.gru_decoder2 = nn.GRU(input_size=gru_size,
                                hidden_size=gru_size,
                                num_layers=1,
                                batch_first=True,
                                bidirectional=False)

        self.timedis = TimeDistributedLayer(nn.Linear(gru_size,1),True)

    def encode(self, x):
        x, _ = self.gru_encoder(x)
        x = self.dropout(x)
        x, _ = self.gru_encoder2(x)
        x = x[:,-1]
        x = self.dropout(x)
        mean, log_var = self.linear_mean(x), self.linear_logvar(x)
        return mean, log_var
    def decode(self, z, aux):
        z = torch.unsqueeze(z, dim=1).repeat(1,self.dim_seq,1)
        cat = torch.cat((aux, z),dim=2)
        out, _ = self.gru_decoder(cat)
        out = self.dropout(out)
        out, _ = self.gru_decoder2(out)
        out = self.timedis(out)
        return out
    def forward(self, x):

        self.dim_seq = x.shape[1]
        main, aux = x, torch.unsqueeze(x[:,:,0], dim=2)
        mean, log_var = self.encode(main)
        std = torch.exp(0.5*log_var)
        q = torch.distributions.Normal(mean, std)
        z = q.rsample()
        out = self.decode(z, aux)
        return out, mean, log_var