PyTorch and TensorFlow give different results

I would like to know whether identical models in PyTorch and TensorFlow that undergo identical training should produce the same results, and whether the intermediate weights and the loss should match at every training iteration. If not, what causes the difference?

Below is my code, in which I have a single-layer model in both PyTorch and TensorFlow, but I am getting different results (assertion errors). Why is this happening? How small must a difference be to count as a reasonable discrepancy (I realize this varies by use case)?

# %%
# imports
import tensorflow as tf
import torch
from torch import nn
from itertools import count
import time

# %%
# Make PyTorch create float32 tensors by default.
torch.set_default_dtype(torch.float32)

# Degree of the target polynomial and its randomly drawn coefficients / bias
# (standard-normal samples scaled by 5).
POLY_DEGREE = 4
W_target = 5 * torch.randn(POLY_DEGREE, 1)
b_target = 5 * torch.randn(1)

# %% [markdown]
# ## Initialize Models

# %% [markdown]
# ### TensorFlow model

# %%
# One Dense unit with no activation: an affine map from W_target.size(0) inputs to 1 output.
tf_model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, activation=None, input_shape=(W_target.size(0),)),
])
tf_weights = tf_model.get_weights()
tf_weights  # shown as the cell output when run as a notebook

# %% [markdown]
# ### PyTorch Model

# %%
# PyTorch counterpart of the Keras model: a single linear unit over the same inputs.
torch_model = nn.Linear(W_target.shape[0], 1)
torch_weights = torch_model.state_dict()
torch_weights  # shown as the cell output when run as a notebook

# %%
# Copy the Keras weights into the PyTorch model so both start from the same point.
# NOTE(review): the pairing is positional -- it assumes torch's state_dict order
# (weight, bias) lines up with Keras' get_weights() order (kernel, bias).
# Weight entries are transposed because the Keras kernel layout is the transpose
# of torch's (out_features, in_features) weight.
new_weights = {
    key: torch.Tensor(tf_weights[idx].T if key.endswith('weight') else tf_weights[idx])
    for idx, key in enumerate(torch_weights.keys())
}


torch_model.load_state_dict(new_weights)
torch_model.state_dict()

# %% [markdown]
# ## Define some helper functions

# %%
# from pytorch examples/regression
def make_features(x):
    """Stack the powers x, x^2, ..., x^POLY_DEGREE of ``x`` as columns.

    ``x`` is expected to be 1-D (as ``get_batch`` passes it); row i of the
    result is ``[x[i], x[i]**2, ..., x[i]**POLY_DEGREE]``.
    """
    column = x.unsqueeze(1)
    powers = [column ** exponent for exponent in range(1, POLY_DEGREE + 1)]
    return torch.cat(powers, dim=1)


def f(x):
    """Evaluate the target polynomial on a feature matrix.

    Computes ``x @ W_target`` and adds ``b_target`` as a Python scalar; ``x`` is
    the output of ``make_features``.
    """
    weighted = x.mm(W_target)
    return weighted + b_target.item()


def poly_desc(W, b):
    """Return a human-readable description of the polynomial ``sum(W[i] x^(i+1)) + b``.

    Args:
        W: coefficients for x^1 .. x^n in increasing-power order. May be a flat
            sequence of numbers or a tensor/array such as ``W_target`` of shape
            ``(n, 1)``; objects with ``reshape`` are flattened first.
        b: indexable bias; only ``b[0]`` is used.

    Returns:
        A string such as ``'y = +1.00 x^1 -2.50 x^2 +0.50'``.
    """
    # Flatten (n, 1) tensors/arrays: iterating them yields 1-element 1-D tensors,
    # and formatting those with a numeric spec raises TypeError. Converting each
    # coefficient to float makes the format spec apply to a plain number.
    if hasattr(W, "reshape"):
        W = W.reshape(-1)
    terms = "".join(f"{float(w):+.2f} x^{power} " for power, w in enumerate(W, start=1))
    return f"y = {terms}{float(b[0]):+.2f}"


def get_batch(batch_size=32):
    """Sample one training batch from the target polynomial.

    Returns:
        ``(x, y)`` where ``x`` is ``make_features`` applied to ``batch_size``
        standard-normal samples and ``y`` is ``f(x)``.
    """
    features = make_features(torch.randn(batch_size))
    return features, f(features)

# %% [markdown]
# ## Train Models

# %%
# check that initial outputs are the same
# x, y = get_batch(batch_size=1)
# print("tf model output:", tf_model.predict(x.numpy()))
# print("pytorch model output:", torch_model(x))

# %%
# Loss and optimizer for both frameworks: mean-reduced MSE and Adam (lr=1e-3).
# For targets of shape (batch, 1), nn.MSELoss() (mean over all elements) and
# Keras' MeanSquaredError (mean over the last axis, then over the batch) agree.
torch_loss = nn.MSELoss()
# torch.optim.Adam defaults to eps=1e-8 whereas Keras' Adam defaults to
# epsilon=1e-7, so set both explicitly to the same value (betas already agree).
# Note: Keras adds epsilon to sqrt(v) *before* bias correction ("epsilon hat"),
# while PyTorch adds it after, so tiny update differences remain even with
# matching values.
torch_optim = torch.optim.Adam(
    torch_model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-7
)

tf_optim = tf.keras.optimizers.Adam(
    learning_rate=1e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-7
)
tf_loss = tf.keras.losses.MeanSquaredError()
tf_model.compile(tf_optim, loss=tf_loss)

# %%
# Tolerances for every TF-vs-PyTorch comparison below. The two frameworks use
# different float32 kernels and reduction orders, so bit-exact agreement is not
# expected, and tiny per-step differences compound as training proceeds.
# torch.allclose's defaults (rtol=1e-5, atol=1e-8) sit at the edge of float32
# round-off, so use explicit, somewhat looser bounds instead.
RTOL = 1e-4
ATOL = 1e-6

for batch_idx in count(1):
    print("starting batch", batch_idx)
    # Both frameworks train on the same batch.
    batch_x, batch_y = get_batch(batch_size=32)

    # --- Train TensorFlow model: one optimizer step on this batch ---
    tf_dict = tf_model.train_on_batch(
        x=batch_x.numpy(),
        y=batch_y.numpy(),
        return_dict=True
    )

    # --- Train PyTorch model: one optimizer step on this batch ---
    torch_optim.zero_grad()
    y_preds = torch_model(batch_x)
    torch_fit_loss = torch_loss(y_preds, batch_y)
    torch_fit_loss.backward()
    torch_optim.step()

    # --- Compare losses ---
    # The PyTorch loss was computed before torch_optim.step(); the Keras loss is
    # presumably also the pre-update forward-pass loss -- confirm for your TF version.
    tf_fit_loss = torch.Tensor([tf_dict["loss"]])
    torch_loss_value = torch_fit_loss.detach()
    loss_diff = tf_fit_loss - torch_loss_value
    losses_match = torch.allclose(tf_fit_loss, torch_loss_value, rtol=RTOL, atol=ATOL)
    print("TF Loss:", tf_fit_loss)
    print("PyTorch Loss:", torch_fit_loss)
    print("TF Loss with PyTorch pred:", tf_loss(batch_y.numpy(), y_preds.detach().numpy()))
    print(losses_match, "diff:", loss_diff)
    assert losses_match, f"losses are not the same -- tf_loss:{tf_fit_loss}, torch_loss:{torch_fit_loss}, diff:{loss_diff}"

    # --- Compare parameters and predictions after the step ---
    torch_weights = torch_model.state_dict()
    tf_weights = tf_model.get_weights()
    print(torch_weights)
    print(tf_weights)

    x, y = get_batch(batch_size=1)
    # Calling the Keras model directly is the cheap way to run one small batch;
    # Model.predict() adds per-call overhead and prints a progress bar each step.
    tf_pred = torch.tensor(tf_model(x.numpy(), training=False).numpy())
    with torch.no_grad():
        torch_pred = torch_model(x)
    assert torch.allclose(tf_pred, torch_pred, rtol=RTOL, atol=ATOL), f"predictions are not the same -- tf_pred:{tf_pred}, torch_pred:{torch_pred}, diff:{tf_pred-torch_pred}"

    # Positional pairing, as in the initial weight copy; weight entries are
    # transposed from the Keras kernel layout to torch's (out, in) layout.
    for idx, key in enumerate(torch_weights.keys()):
        if key.endswith('weight'):
            tf_weight = torch.Tensor(tf_weights[idx].T)
            print("(weights) torch_weights[key]:", torch_weights[key])
            print("(weights) torch.Tensor(tf_weights[idx].T)):", tf_weight)
            assert torch.allclose(torch_weights[key], tf_weight, rtol=RTOL, atol=ATOL), f"weights are not equal. differences: {torch_weights[key] - tf_weight}"
            print("weights are equal")
        else:
            tf_bias = torch.Tensor(tf_weights[idx])
            print("(bias) torch_weights[key]:", torch_weights[key])
            print("(bias) torch.Tensor(tf_weights[idx])):", tf_bias)
            assert torch.allclose(torch_weights[key], tf_bias, rtol=RTOL, atol=ATOL), f"biases are not equal. differences: {torch_weights[key] - tf_bias}"
            print("biases are equal")

    print("batch finished")
    # Stop once the PyTorch model fits the target polynomial closely.
    if torch_fit_loss.item() < 1e-3:
        break