Difference between the PyTorch and TensorFlow versions of a Siamese neural network

I rebuilt in PyTorch a Siamese neural network with transformer encoders that was originally written in TensorFlow. However, my PyTorch model cannot reach an AUC of 0.9 within n epochs, and its training loss stops decreasing. The TensorFlow version reaches 0.9 within the same number of epochs, and its loss keeps going down.

To troubleshoot, I first suspected a difference in architecture, but I made sure the PyTorch model matched the TensorFlow version (at least as far as the code goes). I also compared the number of trainable parameters: 8,017,024 in TensorFlow versus 7,742,496 in PyTorch.

I’m confused about how such a discrepancy in parameter counts is even possible. Do you have any insights into what could be going wrong? Am I doing anything wrong in how I train the model?

You could iterate over the layers and compare the trainable parameters of both implementations to narrow down where the discrepancy comes from.
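A minimal PyTorch sketch of that comparison. The two helpers rely only on `named_parameters()` and `parameters()`; `ToyEncoder` is a hypothetical model, and its sizes (128-dimensional embeddings, 4 heads, `dim_feedforward=512`, a `ModuleList` called `enc_layers`) are assumptions chosen for illustration.

```python
from collections import OrderedDict

import torch.nn as nn


def trainable_params_by_name(model: nn.Module) -> "OrderedDict[str, int]":
    """Number of elements in every trainable parameter, keyed by its dotted name."""
    return OrderedDict(
        (name, p.numel()) for name, p in model.named_parameters() if p.requires_grad
    )


def per_layer_counts(layers: nn.ModuleList) -> list:
    """Total trainable parameters of each block in a ModuleList (one entry per encoder layer)."""
    return [sum(p.numel() for p in layer.parameters() if p.requires_grad) for layer in layers]


class ToyEncoder(nn.Module):
    """Hypothetical model: an embedding followed by a ModuleList of encoder layers."""

    def __init__(self, vocab_size: int = 1000, d_model: int = 128, n_layers: int = 4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # batch_first defaults to False, i.e. inputs shaped (seq_len, batch, d_model).
        # It does not change the parameter count, but it must match the input layout.
        self.enc_layers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(d_model, nhead=4, dim_feedforward=512, batch_first=True)
                for _ in range(n_layers)
            ]
        )


if __name__ == "__main__":
    model = ToyEncoder()
    for name, n in trainable_params_by_name(model).items():
        print(f"Layer: {name}, Parameters: {n}")
    # With d_model=128 and dim_feedforward=512, each entry is 198272.
    print("Per encoder layer:", per_layer_counts(model.enc_layers))
```

On the Keras side, every layer (including a nested `Functional` block) exposes `count_params()`, so the per-block totals from both frameworks can be lined up side by side.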

I did what you suggested and found that the embedding layer parameters are the same in both versions, but the encoder layer parameters differ. PyTorch gives a per-parameter breakdown, while TensorFlow only reports a total for each layer. Here is a snippet of the output:
TensorFlow:
```
Layer Name: embedding_5, Layer Type: <class 'keras.src.layers.core.embedding.Embedding'>, Trainable Parameters: 6432896
Layer Name: functional_51, Layer Type: <class 'keras.src.models.functional.Functional'>, Trainable Parameters: 396032
Layer Name: functional_53, Layer Type: <class 'keras.src.models.functional.Functional'>, Trainable Parameters: 396032
Layer Name: functional_55, Layer Type: <class 'keras.src.models.functional.Functional'>, Trainable Parameters: 396032
Layer Name: functional_57, Layer Type: <class 'keras.src.models.functional.Functional'>, Trainable Parameters: 396032
Layer Name: dropout_29, Layer Type: <class 'keras.src.layers.regularization.dropout.Dropout'>, Trainable Parameters: 0
```

PyTorch:
```
Layer: embedding.weight, Parameters: 6432896
Layer: enc_layers.0.self_attn.in_proj_weight, Parameters: 49152
Layer: enc_layers.0.self_attn.in_proj_bias, Parameters: 384
Layer: enc_layers.0.self_attn.out_proj.weight, Parameters: 16384
Layer: enc_layers.0.self_attn.out_proj.bias, Parameters: 128
Layer: enc_layers.0.linear1.weight, Parameters: 65536
Layer: enc_layers.0.linear1.bias, Parameters: 512
Layer: enc_layers.0.linear2.weight, Parameters: 65536
Layer: enc_layers.0.linear2.bias, Parameters: 128
Layer: enc_layers.0.norm1.weight, Parameters: 128
Layer: enc_layers.0.norm1.bias, Parameters: 128
Layer: enc_layers.0.norm2.weight, Parameters: 128
Layer: enc_layers.0.norm2.bias, Parameters: 128
:
:
:
```
Note that the embedding layer parameters are the same, i.e. 6,432,896, while the parameters per encoder (functional) layer differ: 396,032 in TensorFlow versus 198,272 in PyTorch.
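The per-layer gap can be reproduced with some arithmetic. This is a sketch under stated assumptions: `d_model = 128` and `dim_feedforward = 512` are read off the PyTorch shapes (`out_proj.bias` has 128 entries, `linear1.bias` has 512), and the Keras block is *assumed* to be `MultiHeadAttention` plus two `Dense` layers and two `LayerNormalization` layers, since the TensorFlow code isn't shown. Under those assumptions, 396,032 comes out exactly whenever `num_heads * key_dim = 512` (for example `num_heads=4, key_dim=128`). In Keras, `key_dim` is the size of *each* attention head, whereas PyTorch's `nn.MultiheadAttention` splits `embed_dim` across its heads. So `key_dim=d_model` would make every Keras head as wide as the whole PyTorch attention layer. The TensorFlow total also checks out: 6,432,896 + 4 × 396,032 = 8,017,024.

```python
def torch_encoder_layer_params(d_model: int, dim_ff: int) -> int:
    """Trainable parameters of one nn.TransformerEncoderLayer (default bias=True)."""
    attn = 3 * d_model * d_model + 3 * d_model  # self_attn.in_proj_weight + in_proj_bias (Q, K, V packed)
    attn += d_model * d_model + d_model         # self_attn.out_proj weight + bias
    ffn = (d_model * dim_ff + dim_ff) + (dim_ff * d_model + d_model)  # linear1 + linear2
    norms = 2 * (d_model + d_model)             # norm1 and norm2: weight + bias each
    return attn + ffn + norms


def keras_encoder_block_params(d_model: int, dim_ff: int, num_heads: int, key_dim: int) -> int:
    """Trainable parameters of an ASSUMED Keras block: MultiHeadAttention, Dense(dim_ff),
    Dense(d_model) and two LayerNormalization layers (all with biases, value_dim = key_dim)."""
    inner = num_heads * key_dim                 # key_dim is the size of EACH head in Keras
    attn = 3 * (d_model * inner + inner)        # query, key and value projections
    attn += inner * d_model + d_model           # output projection back to d_model
    ffn = (d_model * dim_ff + dim_ff) + (dim_ff * d_model + d_model)
    norms = 2 * (d_model + d_model)
    return attn + ffn + norms


print(torch_encoder_layer_params(128, 512))               # 198272
print(keras_encoder_block_params(128, 512, 4, 128))       # 396032 (key_dim = d_model)
print(keras_encoder_block_params(128, 512, 4, 128 // 4))  # 198272 (key_dim = d_model // num_heads)
```

If the TensorFlow code does pass `key_dim=d_model`, the two models are not the same architecture, so checking that argument is a quick way to confirm or rule out this explanation.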

I don’t know what else to try or how to go about it.
I believe this difference may be why the model doesn’t reach the required AUC.

Could you also check whether my training loop looks OK from a PyTorch point of view? I’m using a triplet loss function:
```python
import time

import torch

t_start = time.time()
n_iteration = 0  # assumed to start at 0; its initialization is not shown in the original snippet

for i in range(1, n_iter + 1):
    # Training mode (dropout active); set every iteration because eval() is called below
    network_train.train()

    # Step 8: Get batch of semi-hard triplets
    triplets = get_triplets(......)
    # Step 9: Forward pass and compute loss
    anchor = triplets[0].to(device)
    positive = triplets[1].to(device)
    negative = triplets[2].to(device)

    optimizer.zero_grad()
    loss = network_train(anchor, positive, negative)  # Forward pass; the model returns the triplet loss
    loss.backward()  # Backpropagation
    optimizer.step()  # Update weights
    n_iteration += 1

    if i % evaluate_every == 0:
        print("\n ------------- \n")
        print(f"[{n_iteration}] Time for {i} iterations: {(time.time()-t_start)/60.0:.1f} mins, Train Loss: {loss.item()}")

        network_train.eval()  # Evaluation mode (dropout off)
        with torch.no_grad():
            # Step 11: Compute probabilities and evaluate AUC
            probs, yprobs = compute_probs(..)
            fpr, tpr, thresholds, auc = compute_metrics(probs, yprobs)
            print("Validation AUC: ", auc)
```
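Structurally, the loop follows the usual PyTorch sequence: `train()`, `zero_grad()`, forward, `backward()`, `step()`, then `eval()` under `torch.no_grad()`. For comparison, here is a self-contained sketch of the same pattern that runs on random data. `SiameseTripletModel`, `random_triplets`, mean-pooling over the sequence, and `nn.TripletMarginLoss` are all hypothetical stand-ins for `network_train`, `get_triplets` and your loss, not your actual code. The sketch sets `batch_first=True` because `nn.TransformerEncoderLayer` defaults to `batch_first=False` and then expects `(seq_len, batch, d_model)` inputs.

```python
import torch
import torch.nn as nn


class SiameseTripletModel(nn.Module):
    """Hypothetical stand-in for network_train: ONE encoder shared by anchor, positive, negative."""

    def __init__(self, vocab_size: int = 1000, d_model: int = 128, margin: float = 1.0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=4, dim_feedforward=512, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=2)
        self.loss_fn = nn.TripletMarginLoss(margin=margin)

    def embed(self, tokens: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len) ids -> (batch, seq_len, d_model) -> mean over sequence -> (batch, d_model)
        return self.encoder(self.embedding(tokens)).mean(dim=1)

    def forward(self, anchor, positive, negative):
        # The same weights embed all three inputs, which is what makes the network Siamese.
        return self.loss_fn(self.embed(anchor), self.embed(positive), self.embed(negative))


def random_triplets(batch_size: int = 8, seq_len: int = 16, vocab_size: int = 1000):
    """Hypothetical replacement for get_triplets(...): random token ids, no semi-hard mining."""
    return tuple(torch.randint(0, vocab_size, (batch_size, seq_len)) for _ in range(3))


def train(n_iter: int = 20, evaluate_every: int = 10, device: str = "cpu"):
    torch.manual_seed(0)
    model = SiameseTripletModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    losses = []
    for i in range(1, n_iter + 1):
        model.train()  # dropout on; needed again after every eval() below
        anchor, positive, negative = (t.to(device) for t in random_triplets())
        optimizer.zero_grad()
        loss = model(anchor, positive, negative)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if i % evaluate_every == 0:
            model.eval()  # dropout off for validation
            with torch.no_grad():
                val_loss = model(*(t.to(device) for t in random_triplets()))
            print(f"iter {i}: train loss {losses[-1]:.4f}, val loss {val_loss.item():.4f}")
    return losses


if __name__ == "__main__":
    train()
```

On random triplets the loss is not expected to fall in any meaningful way; the sketch only illustrates the mechanics of a shared-weight triplet setup.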