Model works well in TensorFlow but the PyTorch implementation gets stuck

I have this TensorFlow code that uses ragged tensor data (a toy sketch of the ragged routing indices follows, then the model itself).
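
The two routing indices are ragged: link_to_path holds, for each flow/path, the sequence of link indices it traverses, and path_to_link holds, for each link, the (path index, position on that path) pairs of the flows crossing it. Here is a made-up two-path, two-link sample just to show the layout (values are invented; shapes are inferred from how call() indexes them):

import tensorflow as tf

# Hypothetical toy sample, only to illustrate the ragged layout the model expects.
link_to_path = tf.ragged.constant([[0, 1], [1]])   # path 0 traverses links 0 then 1; path 1 traverses link 1
path_to_link = tf.ragged.constant(
    [[[0, 0]],                                     # link 0: crossed by path 0 at position 0
     [[0, 1], [1, 0]]],                            # link 1: path 0 at position 1, path 1 at position 0
    ragged_rank=1,
)

The model itself: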

import tensorflow as tf
import keras.backend as K

class Baseline_cbr_mb(tf.keras.Model):
    mean_std_scores_fields = {
        "flow_traffic",
        "flow_packets",
        "flow_pkts_per_burst",
        "flow_bitrate_per_burst",
        "flow_packet_size",
        "flow_p90PktSize",
        "rate",
        "flow_ipg_mean",
        "ibg",
        "flow_ipg_var",
        "link_capacity",
    }
    mean_std_scores = None

  

    def __init__(self, override_mean_std_scores=None, name=None):
        super(Baseline_cbr_mb, self).__init__()

        self.iterations = 12
        self.path_state_dim = 16
        self.link_state_dim = 16

        if override_mean_std_scores is not None:
            self.set_mean_std_scores(override_mean_std_scores)
        if name is not None:
            assert type(name) == str, "name must be a string"
            self.name = name

        self.attention = tf.keras.Sequential(
            [tf.keras.layers.Input(shape=(None, None, self.path_state_dim)),
            tf.keras.layers.Dense(
                self.path_state_dim, activation=tf.keras.layers.LeakyReLU(alpha=0.01)    
            ),
            ]
        )

        # GRU Cells used in the Message Passing step
        self.path_update = tf.keras.layers.RNN(
            tf.keras.layers.GRUCell(self.path_state_dim, name="PathUpdate",
            ),
            return_sequences=True,
            return_state=True,
            name="PathUpdateRNN",
        )
        self.link_update = tf.keras.layers.GRUCell(
            self.link_state_dim, name="LinkUpdate",
        )

        self.flow_embedding = tf.keras.Sequential(
            [
                tf.keras.layers.Input(shape=13),
                tf.keras.layers.Dense(
                    self.path_state_dim, activation=tf.keras.activations.selu,
                    kernel_initializer='lecun_uniform',
                    ),
                tf.keras.layers.Dense(
                    self.path_state_dim, activation=tf.keras.activations.selu,
                    kernel_initializer='lecun_uniform',
                    )
            ],
            name="PathEmbedding",
        )

        self.link_embedding = tf.keras.Sequential(
            [
                tf.keras.layers.Input(shape=3),
                tf.keras.layers.Dense(
                    self.link_state_dim, activation=tf.keras.activations.selu,
                    kernel_initializer='lecun_uniform',
                    ),
                tf.keras.layers.Dense(
                    self.link_state_dim, activation=tf.keras.activations.selu,
                    kernel_initializer='lecun_uniform',
                    )
            ],
            name="LinkEmbedding",
        )

        self.readout_path = tf.keras.Sequential(
            [
                tf.keras.layers.Input(shape=(None, self.path_state_dim)),
                tf.keras.layers.Dense(
                    self.link_state_dim // 2, activation=tf.keras.activations.selu,
                    kernel_initializer='lecun_uniform',
                    ),
                tf.keras.layers.Dense(
                    self.link_state_dim // 4, activation=tf.keras.activations.selu,
                    kernel_initializer='lecun_uniform',
                    ),
                tf.keras.layers.Dense(1, activation=tf.keras.activations.softplus)
            ],
            name="PathReadout",
        )
    
    def set_mean_std_scores(self, override_mean_std_scores):
        assert (
            type(override_mean_std_scores) == dict
            and all(kk in override_mean_std_scores for kk in self.mean_std_scores_fields)
            and all(len(val) == 2 for val in override_mean_std_scores.values())
        ), "overriden mean-std dict is not valid!"
        self.mean_std_scores = override_mean_std_scores

    @tf.function
    def call(self, inputs):
        # Ensure that the min-max scores are set
        assert self.mean_std_scores is not None, "the model cannot be called before setting the min-max scores!"

        # Process raw inputs
        flow_traffic = inputs["flow_traffic"]
        flow_packets = inputs["flow_packets"]
        global_delay = inputs["global_delay"]
        global_losses = inputs["global_losses"]
        max_link_load = inputs["max_link_load"]
        flow_pkt_per_burst = inputs["flow_pkts_per_burst"]
        flow_bitrate = inputs["flow_bitrate_per_burst"]
        flow_packet_size = inputs["flow_packet_size"]
        flow_type = inputs["flow_type"]
        flow_ipg_mean = inputs["flow_ipg_mean"]
        flow_length = inputs["flow_length"]
        ibg = inputs["ibg"]
        flow_p90pktsize = inputs["flow_p90PktSize"]
        cbr_rate = inputs["rate"]
        flow_ipg_var = inputs["flow_ipg_var"]
        link_capacity = inputs["link_capacity"]
        link_to_path = inputs["link_to_path"]
        path_to_link = inputs["path_to_link"]

        flow_pkt_size_normal = (flow_packet_size - self.mean_std_scores["flow_packet_size"][0]) \
                    * self.mean_std_scores["flow_packet_size"][1]

        path_gather_traffic = tf.gather(flow_traffic, path_to_link[:, :, 0])
        load = tf.math.reduce_sum(path_gather_traffic, axis=1) / (link_capacity * 1e9)
        normal_load = tf.math.divide(load, tf.squeeze(max_link_load))
        
        # Initialize the initial hidden state for paths
        path_state = self.flow_embedding(
            tf.concat(
                [
                    (flow_traffic - self.mean_std_scores["flow_traffic"][0])
                    * self.mean_std_scores["flow_traffic"][1],
                    (flow_packets - self.mean_std_scores["flow_packets"][0])
                    * self.mean_std_scores["flow_packets"][1],
                    (ibg - self.mean_std_scores["ibg"][0])
                    * self.mean_std_scores["ibg"][1],
                    (cbr_rate - self.mean_std_scores["rate"][0])
                    * self.mean_std_scores["rate"][1],
                    (flow_p90pktsize - self.mean_std_scores["flow_p90PktSize"][0])
                    * self.mean_std_scores["flow_p90PktSize"][1],
                    (flow_packet_size - self.mean_std_scores["flow_packet_size"][0])
                    * self.mean_std_scores["flow_packet_size"][1],
                    (flow_bitrate - self.mean_std_scores["flow_bitrate_per_burst"][0])
                    * self.mean_std_scores["flow_bitrate_per_burst"][1],
                    (flow_ipg_mean - self.mean_std_scores["flow_ipg_mean"][0])
                    * self.mean_std_scores["flow_ipg_mean"][1],
                    (flow_ipg_var - self.mean_std_scores["flow_ipg_var"][0])
                    * self.mean_std_scores["flow_ipg_var"][1],
                    (flow_pkt_per_burst - self.mean_std_scores["flow_pkts_per_burst"][0])
                    * self.mean_std_scores["flow_pkts_per_burst"][1],
                    tf.expand_dims(tf.cast(flow_length, dtype=tf.float32), 1),
                    flow_type
                ],
                axis=1,
            )
        )

        # Initialize the initial hidden state for links
        link_state = self.link_embedding(
            tf.concat(
                [
                   (link_capacity - self.mean_std_scores["link_capacity"][0])
                    * self.mean_std_scores["link_capacity"][1],
                    load,
                    normal_load,
                ],
                axis=1,
            ),
        )
        

        # Iterate t times doing the message passing
        for _ in range(self.iterations):
            ####################
            #  LINKS TO PATH   #
            ####################
            
            link_gather = tf.gather(link_state, link_to_path, name="LinkToPath")

            previous_path_state = path_state
            path_state_sequence, path_state = self.path_update(
                link_gather, initial_state=path_state
            )
            
            # We select the element in path_state_sequence so that it corresponds to the state before the link was considered
            path_state_sequence = tf.concat(
                [tf.expand_dims(previous_path_state, 1), path_state_sequence], axis=1
            )
            
            ###################
            #   PATH TO LINK  #
            ###################
            path_gather = tf.gather_nd(
                path_state_sequence, path_to_link, name="PathToLink"
            )
            
            attention_coef = self.attention(path_gather)
            normalized_score = K.softmax(attention_coef)
            weighted_score = normalized_score * path_gather
            
            path_gather_score = tf.math.reduce_sum(weighted_score, axis=1)
            
            link_state, _ = self.link_update(path_gather_score, states=link_state)

        ################
        #  READOUT     #
        ################

        occupancy = self.readout_path(path_state_sequence[:, 1:])

        capacity_gather = tf.gather(link_capacity, link_to_path)
        
        queue_delay = occupancy / capacity_gather
        queue_delay = tf.math.reduce_sum(queue_delay, axis=1)

        return queue_delay
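
For reference, this is roughly how I drive the TF model (sketch; normalizers is my dict of per-field (mean, scale) pairs covering mean_std_scores_fields, and sample is one input dict shaped like the toy example above):

tf_model = Baseline_cbr_mb()
tf_model.set_mean_std_scores(normalizers)   # must be set before the first call
queue_delay = tf_model(sample)              # one predicted delay per flow/path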

I created a PyTorch implementation, but it gets stuck about 20% worse than the TF model. I have been checking it for weeks but can't figure out why.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from tqdm import tqdm
from torchmetrics import MeanAbsolutePercentageError  # assuming the torchmetrics MAPE metric used below
class PathEmbedding(nn.Module):
    def __init__(self, path_state_dim):
        super(PathEmbedding, self).__init__()
        self.path_state_dim = path_state_dim
        self.flow_embedding = nn.Sequential(
            nn.Linear(13, self.path_state_dim),
            nn.GELU(),
            nn.Linear(self.path_state_dim, self.path_state_dim),
            nn.GELU()
        )
        self.initialize_weights()

    def initialize_weights(self):
        for layer in self.flow_embedding:
            if isinstance(layer, nn.Linear):
                # Use LeCun (Xavier) uniform initialization
                init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain('selu'))
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)
    def forward(self, x):
        return self.flow_embedding(x)

class LinkEmbedding(nn.Module):
    def __init__(self, link_state_dim):
        super(LinkEmbedding, self).__init__()
        self.link_state_dim = link_state_dim
        self.link_embedding = nn.Sequential(
            nn.Linear(3, self.link_state_dim),
            nn.SELU(),
            nn.Linear(self.link_state_dim, self.link_state_dim),
            nn.SELU()
        )
        self.initialize_weights()

    def initialize_weights(self):
        for layer in self.link_embedding:
            if isinstance(layer, nn.Linear):
                # Use LeCun (Xavier) uniform initialization
                init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain('selu'))
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)

    def forward(self, x):
        return self.link_embedding(x)

class PathReadout(nn.Module):
    def __init__(self, path_state_dim, link_state_dim):
        super(PathReadout, self).__init__()
        self.path_state_dim = path_state_dim
        self.link_state_dim = link_state_dim
        self.readout_path = nn.Sequential(
            nn.Linear(self.path_state_dim, self.link_state_dim // 2),
            nn.SELU(),
            nn.Linear(self.link_state_dim // 2, self.link_state_dim // 4),
            nn.SELU(),
            nn.Linear(self.link_state_dim // 4, 1),
            nn.Softplus()
        )
        self.initialize_weights()

    def initialize_weights(self):
        for layer in self.readout_path:
            if isinstance(layer, nn.Linear):
                # Use LeCun (Xavier) uniform initialization
                init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain('selu'))
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)
    def forward(self, x):
        return self.readout_path(x)
class MyCombinedModel(nn.Module):
    def __init__(self, link_state_dim, path_state_dim, ):
        super(MyCombinedModel, self).__init__()
        self.link_embed = LinkEmbedding(link_state_dim)  # Assuming LinkEmbedding is defined
        self.flow_embedded = PathEmbedding(path_state_dim)  # Assuming PathEmbedding is defined
        self.reader = PathReadout(path_state_dim, link_state_dim)  # Assuming PathReadout is defined
        self.link_update = nn.GRUCell(
          link_state_dim, link_state_dim
        )

        self.attention = nn.Sequential(
            nn.Linear(path_state_dim, path_state_dim),
            nn.LeakyReLU(negative_slope=0.01)
        )

        # self.path_update = nn.GRUCell(
        #   path_state_dim, path_state_dim
        # )
        self.path_update  = nn.GRU(path_state_dim, path_state_dim, batch_first=True)

    def forward(self, flow_traffic_unnormalized, link_to_path, path_to_link, link_capacity, all_flows_used, flow_traffic, link_capacity_orig, max_link_load):
        path_state = self.flow_embedded(all_flows_used)
        # print(path_to_link[0], 'and ', torch.select(path_to_link, 2, 0))

        loads = []

        for i in range(len(path_to_link)):
            flow_idx = torch.tensor(path_to_link[i][:, 0]).type(torch.int64)
            traffic_on_link = torch.index_select(torch.tensor(flow_traffic_unnormalized), 0, flow_idx)
            loads.append(torch.sum(traffic_on_link) / (link_capacity_orig[i] * 1e9))

        normal_load  = torch.divide(torch.tensor(loads).unsqueeze(1), torch.tensor(max_link_load))

        loads_capacities = torch.cat((torch.tensor(loads).unsqueeze(1), normal_load, torch.tensor(link_capacity)), dim=1)
        # loads_capacities = torch.cat((all_link_capacities_used.unsqueeze(1), link_capacity), dim=1)
        link_state = self.link_embed(loads_capacities)
        path_state = torch.unsqueeze(path_state, axis=0)
        for _ in range(12):
            link_gather = [link_state[link_to_pat] for link_to_pat in link_to_path]
            link_gather = torch.nested.nested_tensor(link_gather)

            # print("path_state= ", torch.nested.to_padded_tensor(link_gather, 0.0).shape)
            path_state_sequence, path_state = self.path_update(torch.nested.to_padded_tensor(link_gather, 0.0),path_state )
            # print("got path state", path_state_sequence[0].shape,path_state_sequence[1].shape )
            prev_path_state = path_state
            path_gather = []
            path_gather_sum = []
            for path_to_lin in path_to_link:
              path_g_now = path_state_sequence[path_to_lin[:, 0], path_to_lin[:, 1]]
              path_g_now = path_g_now * F.softmax(self.attention(path_g_now), dim= 0)
              path_gather_sum.append(torch.sum(path_g_now, dim=0))

            path_sum = torch.stack(path_gather_sum)

            link_state = self.link_update(path_sum, link_state)

        capacity_gather = [link_capacity_orig[values] for values in link_to_path]

        delays = []
        for i, path_state_seq in enumerate(path_state_sequence):
            occupancy = self.reader(path_state_seq[1:, :])
            delays.append(torch.sum(occupancy[:len(capacity_gather[i])] / torch.tensor(capacity_gather[i]).T[0]).unsqueeze(0))

        return torch.cat(delays)


link_state_dim = 16
path_state_dim = 16

from torch.optim.lr_scheduler import ReduceLROnPlateau
num_layers = 2  # Example value for the number of layers
model = MyCombinedModel(link_state_dim,  path_state_dim)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5)
criterion = MeanAbsolutePercentageError()

def train():
  model.train()
  # fake_batch_size= 1
  with tqdm(train_dataset, total=len(train_dataset)) as pbar:
    predictions = []
    y_all = []
    # loss = 0
    # loss = torch.tensor([0.0])

    for i,batch in enumerate(pbar):

        data, labels = batch  #

        devices = data["devices"]
        flow_traffic_unnormalized = data['flow_traffic'].numpy()
        flow_traffic = (data['flow_traffic'].numpy() - normalizers["flow_traffic"][0]) *normalizers["flow_traffic"][1]
        flow_packets = (data['flow_packets'].numpy() - normalizers["flow_packets"][0]) *normalizers["flow_packets"][1]
        max_link_load = data["max_link_load"].numpy()
        flow_pkt_per_burst = (data['flow_pkts_per_burst'].numpy() - normalizers["flow_pkts_per_burst"][0]) *normalizers["flow_pkts_per_burst"][1]
        flow_bitrate = (data['flow_bitrate_per_burst'].numpy() - normalizers["flow_bitrate_per_burst"][0]) *normalizers["flow_bitrate_per_burst"][1]
        flow_packet_size = (data['flow_packet_size'].numpy() - normalizers["flow_packet_size"][0])  * normalizers["flow_packet_size"][1]
        flow_type = data['flow_type'].numpy()
        flow_ipg_mean = (data['flow_ipg_mean'].numpy() - normalizers["flow_ipg_mean"][0])  * normalizers["flow_ipg_mean"][1]
        flow_ipg_var = (data['flow_ipg_var'].numpy() - normalizers["flow_ipg_var"][0])  * normalizers["flow_ipg_var"][1]

        cbr_rate = (data['rate'].numpy() - normalizers["rate"][0]) *normalizers["rate"][1]

        flow_length = data['flow_length'].numpy()
        ibg = (data['ibg'].numpy() - normalizers["ibg"][0]) *normalizers["ibg"][1]
        flow_p90pktsize = (data['flow_p90PktSize'].numpy() - normalizers["flow_p90PktSize"][0]) *normalizers["flow_p90PktSize"][1]

        link_capacity = (data['link_capacity'].numpy()  - normalizers["link_capacity"][0]) *normalizers["link_capacity"][1]
        link_capacity_orig = data['link_capacity'].numpy()
        link_to_path = data['link_to_path'].numpy()
        path_to_link = data['path_to_link'].numpy()
        flow_pkt_size_normal = flow_packets


        y_vals = torch.tensor(labels)



        all_flows_used = torch.tensor(np.concatenate([flow_traffic, ibg, cbr_rate, flow_p90pktsize, flow_packets, flow_type,  flow_packet_size, flow_bitrate, flow_ipg_mean, flow_ipg_var, flow_pkt_per_burst,    np.expand_dims(flow_length,axis=1)], axis=1,dtype=np.float32 ) )

        # all_flows_used = torch.tensor(np.concatenate([flow_traffic,flow_packets, flow_type, flow_type, np.expand_dims(flow_length,axis=1)], axis=1,dtype=np.float32 ) )



        outputs = model(flow_traffic_unnormalized, link_to_path, path_to_link, link_capacity,all_flows_used, flow_traffic, link_capacity_orig,max_link_load)

        predictions.extend(outputs.detach().numpy())
        loss = criterion(outputs, y_vals ) # Compute the loss.
        # if i%fake_batch_size == 0 or i== len(train_dataset):

        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.
        # loss = torch.tensor([0.0])
        # print(outputs[:5],y_vals[:5] )
        pbar.set_description(f"Loss: {loss:.4f}  lr: {scheduler.optimizer.param_groups[0]['lr']}")
        y_all.extend(labels)
    print("train loss = ",criterion(torch.tensor(predictions), torch.tensor(y_all )).item())

Can you please help me figure out what the difference could be?
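
In case it helps narrow things down, this is the kind of module-level parity check I can run: copy the weights of one Keras sub-model into the matching torch module and compare outputs on the same input (Keras Dense kernels are stored as (in, out) and torch Linear weights as (out, in), hence the transpose; tf_model / torch_model stand for instances of Baseline_cbr_mb and MyCombinedModel above):

import numpy as np
import torch

def copy_dense_stack(keras_seq, torch_seq):
    # Copy Keras Dense kernels/biases into the nn.Linear layers of a matching nn.Sequential.
    linears = [m for m in torch_seq if isinstance(m, torch.nn.Linear)]
    denses = [l for l in keras_seq.layers if l.get_weights()]
    for lin, dense in zip(linears, denses):
        kernel, bias = dense.get_weights()
        lin.weight.data = torch.tensor(kernel.T)
        lin.bias.data = torch.tensor(bias)

x = np.random.rand(5, 3).astype(np.float32)   # fake link features, 3 per link as in LinkEmbedding
copy_dense_stack(tf_model.link_embedding, torch_model.link_embed.link_embedding)
tf_out = tf_model.link_embedding(x).numpy()
torch_out = torch_model.link_embed(torch.tensor(x)).detach().numpy()
print(np.abs(tf_out - torch_out).max())       # ~0 only if the two sub-modules really match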