PyTorch Adam vs Tensorflow Adam

Hi guys, long post incoming.

tl;dr: PyTorch’s Adam consistently performs worse than TensorFlow’s in the exact same setting, and by worse I mean that the PyTorch models cannot be used for this particular application.

Probably similar to this and this

Okay first a bit of background:
I have implemented Raissi et al. in both tf and pytorch.
The whole point is to get the network to approximate the solution to a PDE (the 1D Burgers’ equation in this case). Both frameworks can approximate the solution, but TF’s approximation is much better: it captures complex dynamics (i.e. the formation of a shock wave) while PyTorch’s does not, which suggests the PyTorch version cannot be used to solve more complex equations.
The results presented in the paper are reproducible in tf even when using a structure different from the original implementation. My TF implementation is different from the original one but it is identical to the pytorch one for comparison.
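For reference, the problem as it is set up in the code further down (viscosity 0.01, initial condition -sin(pi x), zero boundary values, 0 <= t <= 1) can be written as follows. The symbol \nu for the viscosity is notation chosen here, not the post's:

```latex
\begin{aligned}
u_t + u\,u_x - \nu\,u_{xx} &= 0, && x \in [-1, 1],\ t \in [0, 1],\ \nu = 0.01,\\
u(0, x) &= -\sin(\pi x),\\
u(t, -1) &= u(t, 1) = 0.
\end{aligned}
```

The network output plays the role of u(t, x), and the loss is the sum of three mean squared errors: the PDE residual on interior points, the initial condition on t = 0, and the boundary condition on x = ±1.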

TF works out of the box while in pytorch I could not replicate the results even when trying a whole lot of different configurations (network architectures, optimizers, etc…)

Now for the experiments:
I have tried to make the results as comparable as possible by doing the following (setup A):


  • Same hyperparameters for Adam (default ones in TF)
  • Same init (Xavier uniform)
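One caveat on "same hyperparameters": the two libraries do not place epsilon in the same spot. As far as I know, tf.train.AdamOptimizer defaults to epsilon=1e-08 (1e-07 is the tf.keras default) and documents its epsilon as the "epsilon hat" of the Kingma and Ba paper, while torch.optim.Adam adds eps after bias-correcting the second moment. The pure-Python sketch below writes out both update rules for a single scalar parameter. The function names `adam_tf_style` and `adam_torch_style` and the gradient sequence are made up for illustration; this is not either library's code.

```python
import math

def adam_tf_style(grads, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, w=0.0):
    """Bias correction folded into the step size; eps added to the raw sqrt(v)."""
    m = v = 0.0
    for t, g in enumerate(grads, start=1):
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g * g
        lr_t = lr * math.sqrt(1 - b2 ** t) / (1 - b1 ** t)
        w -= lr_t * m / (math.sqrt(v) + eps)
    return w

def adam_torch_style(grads, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, w=0.0):
    """eps added after the second moment has been bias-corrected."""
    m = v = 0.0
    for t, g in enumerate(grads, start=1):
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g * g
        denom = math.sqrt(v) / math.sqrt(1 - b2 ** t) + eps
        w -= (lr / (1 - b1 ** t)) * m / denom
    return w

# Hypothetical gradient sequence with "ordinary" magnitudes
grads = [0.5, -0.2, 0.1, 0.3, -0.4] * 20
w_tf = adam_tf_style(grads)
w_pt = adam_torch_style(grads)
print(w_tf, w_pt, abs(w_tf - w_pt))
```

For gradients much larger than eps the two trajectories are practically identical, so this difference alone is unlikely to explain a large gap. It can matter more when gradients are comparable in size to eps.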

When that didn’t work I went even further (setup B):


  • Initialized weights in TF and loaded them in pytorch
  • Initialized weights in pytorch and loaded them in TF

In A, TF’s results are competitive while PyTorch’s are not. B is where it gets interesting: TF converges to a good result with its own weights but not with PyTorch’s, while PyTorch converges with neither its own nor TF’s weights.
A forward pass in PyTorch/TF with the weights loaded from TF/PyTorch gives the exact same answer, so loading the weights is not the problem. Further, the fact that PyTorch approximates the right solution to some extent means the network is correctly wired.
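The weight exchange works because the two layouts are transposes of each other: a TF dense layer stores its kernel as (in_features, out_features), while torch.nn.Linear stores its weight as (out_features, in_features), which is why the loading code uses `.T`. A minimal NumPy sketch of that equivalence, with made-up shapes and random values standing in for real weights:

```python
import numpy as np

rng = np.random.default_rng(0)
kernel = rng.standard_normal((2, 20))  # TF dense kernel: (in_features, out_features)
bias = rng.standard_normal(20)
x = rng.standard_normal((5, 2))        # batch of 5 (t, x) points

# What a TF dense layer computes before the activation
y_tf = x @ kernel + bias

# torch.nn.Linear stores the transpose and computes x @ weight.T + bias
weight = kernel.T
y_torch = x @ weight.T + bias

print(weight.shape, np.max(np.abs(y_tf - y_torch)))
```

Matching forward passes only show that the weights were transferred correctly. They say nothing about whether the gradients match, which is where the two implementations can still differ.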

This is a typical loss plot where TF is in blue and pytorch in orange.

I can’t post more images because I’m a new user but in the plot of the solutions the TF solution approximates the discontinuity in the middle while pytorch can’t quite get there which makes me think it is an optimizer issue.

My ugly code for pytorch and TF below. Code is uncleaned but I’m here to clarify and discuss anything.

import numpy as np
import matplotlib.pyplot as plt
import torch

class PINN(torch.nn.Module):
    def __init__(self):
        # Module.__init__ must run before any submodule is registered
        super().__init__()
        self.dim_real_in = 2
        self.dim_img_in = 0
        self.dim_real_out = 1
        self.dim_img_out = 0
        self.architecture = [20, 20, 20, 20]
        self.activation_functions = ['tanh', 'tanh', 'tanh']
        self.torch_activation = {'tanh': torch.tanh}
        self.layers = self._get_layers()

    def _get_layers(self):
        self.architecture.insert(0, self.dim_real_in + self.dim_img_in)
        l = []
        print("Number of neurons and activation functions not equal. Using tanh for all")
        for i in range(len(self.architecture) - 1):
            layer = torch.nn.Linear(self.architecture[i], self.architecture[i + 1])
            super().add_module("layer"+ str(i), layer)
            # Hidden layer followed by tanh, as in the TF version
            l.append(layer)
            l.append(torch.tanh)

        # Output layer without an activation
        layer = torch.nn.Linear(self.architecture[i + 1], self.dim_img_out + self.dim_real_out)
        super().add_module("layer" + str(i + 1), layer)
        l.append(layer)
        return l

    def _initialize_weights(self):
        s = 0
        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                # This bit loads weights from numpy arrays
                # The loaded weights have identical forward passes as in TF!
                # m.weight = torch.nn.Parameter(torch.from_numpy(np.load(str(s) + '_trained.npy')).T)
                # s +=1
                # m.bias = torch.nn.Parameter(torch.from_numpy(np.load(str(s) + '_trained.npy')))
                # s +=1
                torch.nn.init.xavier_uniform_(m.weight, gain=5/3)
                torch.nn.init.constant_(m.bias, 0)
                # This saves the weight init
                # np.save(str(s) + '_torch.npy', m.weight.detach().numpy())
                # s+=1
                # np.save(str(s) + '_torch.npy', m.bias.detach().numpy())
                # s+=1

    def forward(self, t, x):
        t.requires_grad = True
        x.requires_grad = True
        var = torch.cat((t, x), dim=1)
        for l in self.layers:
            var = l(var)

        H = var
        H_to_dif = H.sum()
        dt, = torch.autograd.grad(H_to_dif, t, create_graph=True)
        dx, = torch.autograd.grad(H_to_dif, x, create_graph=True)
        dxx, = torch.autograd.grad(dx.sum(), x, retain_graph=True)

        return H, dt, dx, dxx

class DataWrapper(torch.utils.data.Dataset):
    def __init__(self, data, labels):
        super().__init__()
        self.data = data
        self.labels = labels

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)

class Schrodinger:
    def __init__(self, targets):

        self.model = self._build_model()
        self.targets = targets

    def _build_model(self):
        model = PINN().to('cuda')
        model = model.double()
        return model

    def physics(self,t ,x):
        u, u_t, u_x, u_xx = self.model.forward(t, x)

        f = u_t + u * u_x - 0.01 * u_xx
        return f, u

    def loss(self, t, x, labels):

        initial_mask = (t[:,0] == 0)
        boundary_mask = ((x[:,0] == -1) | (x[:,0] == 1)) & (t[:,0] != 0)
        structure_mask = ~ (initial_mask | boundary_mask)

        f, u = self.physics(t, x)
        l = torch.zeros_like(u)

        structure_norm = structure_mask.sum()
        boundary_norm = boundary_mask.sum()
        initial_norm = initial_mask.sum()

        l[structure_mask] = f[structure_mask]
        l[boundary_mask] = u[boundary_mask]
        l[initial_mask] = u[initial_mask]

        labels = labels.reshape(-1,1)
        l = (l - labels)**2

        l[structure_mask] = l[structure_mask] /structure_norm
        l[boundary_mask] = l[boundary_mask] / boundary_norm
        l[initial_mask] = l[initial_mask] / initial_norm

        self.structure_loss = [l[structure_mask].sum(), structure_norm]
        self.boundary_loss = [l[boundary_mask].sum(), boundary_norm]
        self.initial_loss = [l[initial_mask].sum(), initial_norm]

        l = l.sum()
        return l

    def train(self, t, x, epochs):
        self.loss_histogram = []
        self.weight_histogram = []
        optimizer = torch.optim.Adam(self.model.parameters(),
                                     lr=1e-3, eps=1e-07, weight_decay=0, amsgrad=False)
        # optimizer = torch.optim.LBFGS(self.model.parameters(), lr=0.1)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.9)
        print("Training params: ", len(list(self.model.parameters())))
        for i in range(epochs):
            w = 0
            self.prt = True
            def closure():
                # Closure form so the same loop also works with LBFGS
                optimizer.zero_grad()
                loss = self.loss(t, x, self.targets[:, 2])
                loss.backward()
                if i % 100 == 0 and self.prt:
                    self.prt = False
                    print("Epoch: {:1.0f}, loss: {:1.5f} ".format(i, loss.item()))
                    print("struc: {:1.5f}, n: {:1.0f}, bdry: {:1.5f}, n: {:1.0f}, initial: {:1.5f}, n: {:1.0f}".format(
                        self.structure_loss[0].item(), self.structure_loss[1].item(),
                        self.boundary_loss[0].item(), self.boundary_loss[1].item(),
                        self.initial_loss[0].item(), self.initial_loss[1].item()))
                return loss

            # step() runs the closure and returns the loss it computed
            loss = optimizer.step(closure)
            self.loss_histogram.append(loss.item())
            # scheduler.step()
            for m in self.model.modules():
                if isinstance(m, torch.nn.Linear):
                    w += (m.weight**2).sum()
        np.savetxt('pytorch.txt', self.loss_histogram)


if __name__ == "__main__":
    from itertools import product

    shape_x = 200
    shape_t = 200
    X = np.linspace(-1, 1, shape_x)
    T = np.linspace(0, 1, shape_t)
    Z = np.zeros(shape_t*shape_x)

    prod = product(T, X)
    prod = np.array(list(prod))
    prod = np.insert(prod, 2 ,0, axis=1)

    sin_form = - np.sin(np.pi*X)
    prod[prod[:,0] == 0, 2] = sin_form
    # Boundary labels: x lives in column 1 (column 0 is t)
    prod[(prod[:,1] == -1) | (prod[:,1] == 1), 2] = 0

    x_coord, t_coord = [], []
    for t, x, z in prod:
        t_coord.append(t)
        x_coord.append(x)
    X = np.array(x_coord).reshape(-1, 1)
    T = np.array(t_coord).reshape(-1, 1)

    dtype = torch.float64

    X_tensor = torch.tensor(X, dtype=dtype).to('cuda')
    T_tensor = torch.tensor(T, dtype=dtype).to('cuda')
    total_data = torch.tensor(prod, dtype=dtype).to('cuda')
    point_t = torch.tensor([[2]], dtype=dtype).to('cuda')
    point_x = torch.tensor([[2]], dtype=dtype).to('cuda')

    model = Schrodinger(total_data)
    model.train(T_tensor, X_tensor, 1000)
    out, _ , _, _ = model.model.forward(T_tensor, X_tensor)

    # Compare forward with forward from TF. If using same weights error = 0
    # out_tf = np.load('output_tf.npy')
    # error = (out.detach().cpu() - out_tf)**2
    # error = error.sum()
    # print(error)
    im = out.reshape(shape_t, shape_x).detach().cpu().T

    for i in range(0, im.shape[1], 50):
        plt.plot(prod[prod[:,0] == 0,1], im[:, i])
    plt.show()


import numpy as np
import tensorflow as tf
import itertools
import matplotlib.pyplot as plt

class PINN:
    def __init__(self, layers, f_x, g_t, alpha):

        self.initial_condition = f_x
        self.boundary_condition = g_t
        self.layers = layers
        self.alpha = alpha

        self.x = tf.placeholder(tf.float64, shape=[None, 1])
        self.t = tf.placeholder(tf.float64, shape=[None, 1])

        self.num_points_space = 200
        self.num_points_time = 200

    def mlp(self, t, x, variable_scope='default'):

        output = tf.concat([t, x], axis=1)

        # Pass through MLP
        with tf.variable_scope(variable_scope, reuse=tf.AUTO_REUSE):
            for i, size in enumerate(self.layers):
                output = tf.nn.tanh(tf.layers.dense(output, size, name="layer_{}".format(i)))

            # Last layer without a tanh
            output = tf.layers.dense(output, 1, name="layer_{}".format(i + 1))

        return output

    def gen_data(self):
        X = np.linspace(-1, 1, self.num_points_space)
        T = np.linspace(0, 1, self.num_points_time)

        prod = itertools.product(T, X)
        x_coord, t_coord = [], []
        for t, x in prod:
            t_coord.append(t)
            x_coord.append(x)
        X = np.array(x_coord).reshape(-1, 1)
        T = np.array(t_coord).reshape(-1, 1)

        return T, X

    def PDE_constraint(self, t, x):

        u = self.mlp(t, x)
        u_t = tf.gradients(u, t, unconnected_gradients='zero')[0]
        u_x = tf.gradients(u, x, unconnected_gradients='zero')[0]
        u_xx = tf.gradients(u_x, x, unconnected_gradients='zero')[0]
        f = u_t + u*u_x - self.alpha*u_xx

        return f

    def loss(self, t, x):

        L_f = tf.square(self.PDE_constraint(t, x)) #- self.poisson_condition(t, x))
        L_u_time = tf.square(self.mlp(t, x) - self.boundary_condition(t))
        L_u_space = tf.square(self.mlp(t, x) - self.initial_condition(x))

        cond = ~tf.equal(t, 0) & ~tf.equal(x, 1) & ~tf.equal(x, -1)
        normalization_f = tf.reduce_sum(tf.cast(cond, tf.float64))
        L_f = tf.where(cond, L_f, tf.zeros_like(L_f))

        cond = (tf.equal(x, 1) | tf.equal(x, -1)) & ~tf.equal(t, 0)
        normalization_u_time = tf.reduce_sum(tf.cast(cond, tf.float64))
        L_u_time = tf.where(cond, L_u_time, tf.zeros_like(L_u_time))

        cond = tf.equal(t, 0)
        normalization_u_space = tf.reduce_sum(tf.cast(cond, tf.float64))
        L_u_space = tf.where(cond, L_u_space, tf.zeros_like(L_u_space))

        return tf.reduce_sum(L_f)/normalization_f + \
               tf.reduce_sum(L_u_space)/normalization_u_space + \
               tf.reduce_sum(L_u_time)/normalization_u_time

    def detailed_loss(self, t, x):

        L_f = tf.square(self.PDE_constraint(t, x)) #- self.poisson_condition(t, x))
        L_u_time = tf.square(self.mlp(t, x) - self.boundary_condition(t))
        L_u_space = tf.square(self.mlp(t, x) - self.initial_condition(x))

        cond = ~tf.equal(t, 0) & ~tf.equal(x, 1) & ~tf.equal(x, -1)
        normalization_f = tf.reduce_sum(tf.cast(cond, tf.float64))
        L_f = tf.where(cond, L_f, tf.zeros_like(L_f))

        cond = (tf.equal(x, 1) | tf.equal(x, -1)) & ~tf.equal(t, 0)
        normalization_u_time = tf.reduce_sum(tf.cast(cond, tf.float64))
        L_u_time = tf.where(cond, L_u_time, tf.zeros_like(L_u_time))

        cond = tf.equal(t, 0)
        normalization_u_space = tf.reduce_sum(tf.cast(cond, tf.float64))
        L_u_space = tf.where(cond, L_u_space, tf.zeros_like(L_u_space))

        return tf.reduce_sum(L_f)/normalization_f, \
               tf.reduce_sum(L_u_space)/normalization_u_space, \
               tf.reduce_sum(L_u_time)/normalization_u_time, \
               normalization_f, normalization_u_space, normalization_u_time

    def train(self, training_iter):

        plot_loss = []
        T, X = self.gen_data()
        loss = self.loss(self.t, self.x)

        tf_dict = {self.t: T, self.x: X}

        optimizer_Adam = tf.train.AdamOptimizer()
        train_op_Adam = optimizer_Adam.minimize(loss)

        init = tf.global_variables_initializer()
        sess = tf.Session()
        sess.run(init)
        # Code below loads weight inits
        # s = 0
        # for v in  tf.trainable_variables():
        #     w = np.load(str(s) + '_torch.npy').T
        #     # print(
        #     s += 1
        # Code below saves weight inits
        # for v in  tf.trainable_variables():
        #     s += 1

        # Saves the forward pass to compare it with pytorch's
        # im = sess.run(self.mlp(T, X))
        # np.save('output_tf.npy', im)
        # im = np.load('output_tf.npy').reshape(self.num_points_space, self.num_points_time)
        # plt.imshow(im.T)

        for it in range(training_iter):
            _, l = sess.run([train_op_Adam, loss], tf_dict)
            plot_loss.append(l)


        np.savetxt('tf.txt', plot_loss)
        print("TOTAL LOSS: ", l)

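        # Network output along x at t = 0, 0.5 and 1 (arguments in mlp's (t, x) order)
        data1 = sess.run(self.mlp(T[T == 0, np.newaxis], X[T == 0, np.newaxis]))
        data2 = sess.run(self.mlp(T[T == 0.5, np.newaxis], X[T == 0.5, np.newaxis]))
        data3 = sess.run(self.mlp(T[T == 1, np.newaxis], X[T == 1, np.newaxis]))

        self.plot(data1, data2, data3)
        total_f = sess.run(self.PDE_constraint(self.t, self.x), tf_dict)
        L_f, L_space, L_time, normalization_f, normalization_u_space, normalization_u_time = sess.run(self.detailed_loss(self.t, self.x), tf_dict)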
        print("TOTAL PDE constraint violation: ", np.sum(np.square(total_f)))
        print("TOTAL f constraint violation: ", L_f)
        print("TOTAL time constraint violation: ", L_time)
        print("TOTAL space constraint violation: ", L_space)
        print("NUMBER f constraint violation: ", normalization_f)
        print("NUMBER time constraint violation: ", normalization_u_time)
        print("NUMBER space constraint violation: ", normalization_u_space)

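        im = sess.run(self.mlp(T, X)).reshape(self.num_points_space, self.num_points_time)
        s = 0
        for v in tf.trainable_variables():
            # Saved as '<s>_trained.npy', the name the PyTorch loader expects
            np.save(str(s) + "_trained", sess.run(v))
            s += 1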
        return sess

    def predict(self,sess, t, x, shape_x, shape_t):
        pred = sess.run(self.mlp(t, x))
        im = pred.reshape(shape_x, shape_t).T
        for i in range(0, im.shape[1], 50):
            plt.plot(im[:, i])
        plt.show()


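    def plot(self, data1, data2, data3):
        # Overlay the predicted profiles for the three time slices
        plt.plot(data1)
        plt.plot(data2)
        plt.plot(data3)
        plt.show()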

def zero(t):
    return tf.zeros_like(t)

def shifted_sin(x):
    return -tf.sin(np.pi * x)

if __name__ == "__main__":

    shape_x = 500
    shape_t = 500
    X = np.linspace(-1, 1, shape_x)
    T = np.linspace(0, 1, shape_t)
    prod = itertools.product(T, X)
    x_coord, t_coord = [], []
    for t, x in prod:
        t_coord.append(t)
        x_coord.append(x)
    X = np.array(x_coord).reshape(-1, 1)
    T = np.array(t_coord).reshape(-1, 1)
    architecture = [20, 20, 20, 20]
    PINN = PINN(architecture, shifted_sin, zero, alpha=0.01)
    sess = PINN.train(1000)
    PINN.predict(sess, T, X, shape_x, shape_t)



I can’t edit or delete the post for some reason. Anyway, after some more reproducible experiments I fixed a few bugs in my code, and now the two optimizers give almost the same result. This is a reasonably small discrepancy, so I’m retracting the post :slight_smile:

TF: 0.005030287380910818
PT: 0.0051292076167447093

Thanks for the update and good to hear you’ve solved the issue. :slight_smile:

Hi Omer, I am currently working on something very similar (using a modified LSTM network to solve some PDEs, one of which is Burgers’ equation), and I too am struggling to get my PyTorch network to converge in line with an identical TF network. If you can still recall the problem above, it would be great if you could share how you solved it in the end, specifically what the ‘bugs’ in the code were and how you fixed them.

Many thanks in advance (apologies for any poor grammar or formatting, this is my first post here).

So I didn’t want to be one of “those guys” who say “nvm, I solved the problem”, but in this case I think it was a lot of small bugs in my own implementation, and I didn’t think anyone could benefit from them.

That said, the most important change I made was to change
dxx, = torch.autograd.grad(dx.sum(), x, retain_graph=True) to
dxx, = torch.autograd.grad(dx.sum(), x, create_graph=True)

This is important because otherwise dxx is computed but not added to the graph, i.e. it is treated as a constant during backprop (differentiation w.r.t. the weights). With no gradient flowing through the dxx term, the optimizer effectively sees the inviscid Burgers’ equation, which is what happened to me!
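A minimal sketch of the difference, assuming a tiny stand-in network rather than the PINN from the post (the architecture, seed and input points are made up for illustration). Both calls return the same numbers for dxx, but only the create_graph=True version stays connected to the weights:

```python
import torch

torch.manual_seed(0)
# Tiny stand-in network (hypothetical, not the PINN from the post)
net = torch.nn.Sequential(
    torch.nn.Linear(1, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1)
).double()

x = torch.linspace(-1.0, 1.0, 5, dtype=torch.float64).reshape(-1, 1)
x.requires_grad_(True)

u = net(x)
dx, = torch.autograd.grad(u.sum(), x, create_graph=True)

# retain_graph=True only keeps the existing graph alive for another call;
# the returned tensor is not itself tracked by autograd.
dxx_const, = torch.autograd.grad(dx.sum(), x, retain_graph=True)

# create_graph=True records how dxx was computed, so it stays differentiable.
dxx_graph, = torch.autograd.grad(dx.sum(), x, create_graph=True)

print("same values:", torch.allclose(dxx_const, dxx_graph))
print("retain_graph -> requires_grad:", dxx_const.requires_grad)
print("create_graph -> requires_grad:", dxx_graph.requires_grad)

# Only the create_graph version can push gradients back to the weights.
(dxx_graph ** 2).sum().backward()
print("grad on first-layer weights:", net[0].weight.grad.abs().sum().item())
```

A quick check of `requires_grad` on every derivative that enters the loss is a cheap way to catch this kind of silent detachment.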

Other than that, if I learned one thing while working on this, it is that training PINNs (I’m guessing you are working on something similar) is extremely tricky.

PS: as the second derivative of ReLU is 0, when your equation has a dxx term somewhere, a ReLU network usually fails to learn it! Took me a while to figure that one out.
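The ReLU point can be checked directly with autograd. The sketch below uses a hypothetical helper, `second_derivative`, and a handful of arbitrary sample points away from zero. It compares the second derivative of ReLU with that of tanh:

```python
import torch

def second_derivative(fn, x):
    """Elementwise d^2 fn / dx^2 via two autograd passes."""
    y = fn(x)
    dy, = torch.autograd.grad(y.sum(), x, create_graph=True)
    d2y, = torch.autograd.grad(dy.sum(), x)
    return d2y

# Arbitrary sample points, none of them at the kink x = 0
x = torch.tensor([-1.5, -0.5, 0.5, 1.5], dtype=torch.float64, requires_grad=True)

d2_relu = second_derivative(torch.relu, x)
d2_tanh = second_derivative(torch.tanh, x)
print("relu:", d2_relu)
print("tanh:", d2_tanh)
```

ReLU is piecewise linear, so its second derivative vanishes everywhere autograd evaluates it, and a u_xx term built from a ReLU network carries no signal. Smooth activations such as tanh do not have this problem.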

Hope this helps.

Edit: So I can’t send you a pm because I’m still a new user u.u I’m still working on this problem and I’ll be happy to talk about it if you want. I also have the working code for this (beware, it is not clean!) that I can share.


That is excellent! I have an almost identical line of code in my work (I am using deep Galerkin methods, which, I believe, live in a neighbourhood not too far away from PINNs). If you are happy to talk about it in more detail, that would be brilliant, as I’ve only just started down this road and any insights you can offer would be massively appreciated.

If you are happy to talk about it (I am a new user also so can’t PM) please just drop me an email here:

It’s a throwaway but I’ll give you my work email directly if that’s ok (just to spare posting my email publicly). Thank you for your fast reply and heavy detail!

Wow. Thank you very much! Who would have thought? I just googled “Adam optimizer, Pytorch vs Tensorflow” and found this. Turns out I made the same mistake as well (a different application, but I also need to set create_graph=True). I guess it is kind of a universal problem that is easy to miss. If I understood PyTorch more thoroughly I would have known, but there is no way I could have caught this problem in a short period of time without reading your post. Well, I guess we all need to calculate higher-order derivatives sometimes :slight_smile:

This helped me a lot! Thank you, Omer.
I am also translating a physics-informed neural network (PINN) from TensorFlow (1.0) to PyTorch and struggled with some of the syntax as well.
At the moment my script (based on this notebook by Perdikaris) runs in PyTorch and gives similar results, but I still have some questions about my code. My script is way slower than TensorFlow (TF trains in about 1 minute, but Torch takes 40 minutes(!!) with the same number of iterations and data points). Have you encountered this problem as well? I figured out that adding create_graph=True in torch.autograd.grad(dx.sum(), x, create_graph=True) slows down the program. Of course, without it the program won’t run properly. Also, the behaviour of the losses is a bit different, as can be seen in the pictures (left is TensorFlow, right is PyTorch).

Do you have any idea how to upgrade the PyTorch code such that the training time will be similar to the training time of the Tensorflow model? If required, I could show (some parts of) the code.


I just found your post, and given the subject matter it is none too long. In particular your followup with the amendment was very beneficial.

This will help a great deal; btw I reran the PyTorch script, both with Adam and BFGS optimizers and they work very well (quite fast even on my 2016 vintage Asus).

Going forward I will send more feedback as I go.

Thanks for sharing this material,
