TF and PyTorch gradients look different

Hi All,

I’m opening this topic because I’m developing a model that computes gradients of some outputs with respect to some inputs, and I’m getting different results depending on the library used.
I have PyTorch and TensorFlow implementations with the same initialization, and they compute slightly different gradients.
The models are fully connected with tanh nonlinearities.

PyTorch (1.5.0):

# x_train, y_train and t_train are numpy arrays, each with shape (5,)

x_t = torch.from_numpy(x_train).float()
y_t = torch.from_numpy(y_train).float()
t_t = torch.from_numpy(t_train).float()
x_t.requires_grad = True
y_t.requires_grad = True
t_t.requires_grad = True

X = torch.stack([x_t, y_t, t_t], dim=1)
lb, _ = X.min(0)
ub, _ = X.max(0)
H = 2.0*(X - lb)/(ub - lb) - 1.0

psi_and_p = model(H)
psi = psi_and_p.select(dim=1, index=0)
p = psi_and_p.select(dim=1, index=1)

dummy = torch.ones(len(X), requires_grad=False)
u = torch.autograd.grad(psi, y_t, grad_outputs=dummy, create_graph=True, only_inputs=True)[0]
v = - torch.autograd.grad(psi, x_t, grad_outputs=dummy, create_graph=True, only_inputs=True)[0]

The outcomes are
psi (first column of the output of the NN):

tensor([ 0.0577,  0.0858, -0.1348, -0.0789, -0.0031], grad_fn=<SelectBackward>)

u (gradient of psi wrt y):

tensor([ 0.0489, -0.0578,  0.0303, -0.0616,  0.0401], grad_fn=<UnbindBackward>)

v (negative gradient of psi wrt x):

tensor([-0.0788,  0.1062,  0.1445, -0.0767, -0.0951], grad_fn=<NegBackward>)

TensorFlow (1.15.3):

# function that takes the input and returns the outputs of the network
def neural_net(self, X, weights, biases):
    num_layers = len(weights) + 1
    
    H = 2.0*(X - self.lb)/(self.ub - self.lb) - 1.0
    for l in range(0,num_layers-2):
        W = weights[l]
        b = biases[l]
        H = tf.tanh(tf.add(tf.matmul(H, W), b))
    W = weights[-1]
    b = biases[-1]
    Y = tf.add(tf.matmul(H, W), b)
    return Y

# function that computes the gradients
def net_NS(self, x, y, t):
    psi_and_p = self.neural_net(tf.concat([x,y,t], 1), self.weights, self.biases)
    psi = psi_and_p[:,0:1]
    p = psi_and_p[:,1:2]
    
    u = tf.gradients(psi, y)[0]
    v = -tf.gradients(psi, x)[0]  

The outcomes are
psi (first column of the output of the NN):

[[ 0.05774451]
 [ 0.08581609]
 [-0.13483115]
 [-0.07894167]
 [-0.00312973]]

u (gradient of psi wrt y):

[[0.048948  ]
 [0.05040099]
 [0.03029619]
 [0.04972227]
 [0.04007644]]

v (negative gradient of psi wrt x):

[[-0.0788489 ]
 [-0.07497928]
 [-0.07125   ]
 [-0.07668343]
 [-0.0951325 ]]

The outputs are exactly the same, while the gradients differ only for some entries.

Hi,

Where do you actually compute lb and ub in the tensorflow code?
Could you provide a code sample that we can run that reproduces the issue?

Hi Alban,

I wrote a simplified case that shows this strange behavior well. The prediction made by PyTorch matches the one from TF perfectly, but the gradients are different in some positions…
Essentially, I’m trying to rewrite the following repo in PyTorch:

Anyway, here is the simplified case (it requires tensorflow==1.15.3 and torch==1.5.0):

import numpy as np
import tensorflow as tf
import torch
from torch import nn
from torch.nn import functional as F
np.random.seed(1234)
tf.set_random_seed(1234)

# model structure
layers = [3, 20, 20, 20, 20, 20, 20, 20, 20, 2]

# input and output vectors
x_train = np.array([[5.73737374],
       [6.02020202],
       [3.19191919],
       [3.26262626],
       [4.18181818]])
y_train = np.array([[-0.28571429],
       [-1.02040816],
       [ 0.7755102 ],
       [ 1.51020408],
       [ 0.6122449 ]])
t_train = np.array([[17.2],
       [ 9.8],
       [13.4],
       [ 8.6],
       [ 1.5]])
u_train = np.array([[0.75602438],
       [0.7348674 ],
       [1.07155909],
       [1.16107105],
       [0.96089615]])
v_train = np.array([[ 0.42616224],
       [-0.25002337],
       [-0.06036023],
       [ 0.09509021],
       [-0.46543169]])

# Tensorflow
#-----------

class PhysicsInformedNN:
    # Initialize the class
    def __init__(self, x, y, t, u, v, layers):
        
        X = np.concatenate([x, y, t], 1)
        
        self.lb = X.min(0)
        self.ub = X.max(0)
                
        self.X = X
        
        self.x = X[:,0:1]
        self.y = X[:,1:2]
        self.t = X[:,2:3]
        
        self.u = u
        self.v = v

        self.layers = layers
        
        # Initialize NN
        self.weights, self.biases = self.initialize_NN(layers)        
        
        # Initialize parameters
        self.lambda_1 = tf.Variable([0.0], dtype=tf.float32)
        self.lambda_2 = tf.Variable([0.0], dtype=tf.float32)
        
        # tf placeholders and graph
        self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                     log_device_placement=False))
        
        self.x_tf = tf.placeholder(tf.float32, shape=[None, self.x.shape[1]])
        self.y_tf = tf.placeholder(tf.float32, shape=[None, self.y.shape[1]])
        self.t_tf = tf.placeholder(tf.float32, shape=[None, self.t.shape[1]])
        
        self.u_tf = tf.placeholder(tf.float32, shape=[None, self.u.shape[1]])
        self.v_tf = tf.placeholder(tf.float32, shape=[None, self.v.shape[1]])
        
        self.psi_pred, self.u_pred, self.v_pred, self.p_pred, self.f_u_pred, self.f_v_pred = self.net_NS(self.x_tf, self.y_tf, self.t_tf)
        
        self.loss = tf.reduce_sum(tf.square(self.u_tf - self.u_pred)) + \
                    tf.reduce_sum(tf.square(self.v_tf - self.v_pred)) + \
                    tf.reduce_sum(tf.square(self.f_u_pred)) + \
                    tf.reduce_sum(tf.square(self.f_v_pred))
        
        self.optimizer_Adam = tf.train.AdamOptimizer()
        self.train_op_Adam = self.optimizer_Adam.minimize(self.loss)                    
        
        init = tf.global_variables_initializer()
        self.sess.run(init)

    def initialize_NN(self, layers):        
        weights = []
        biases = []
        num_layers = len(layers) 
        for l in range(0,num_layers-1):
            W = self.xavier_init(size=[layers[l], layers[l+1]])
            b = tf.Variable(tf.zeros([1,layers[l+1]], dtype=tf.float32), dtype=tf.float32)
            weights.append(W)
            biases.append(b)        
        return weights, biases
        
    def xavier_init(self, size):
        in_dim = size[0]
        out_dim = size[1]        
        xavier_stddev = np.sqrt(2/(in_dim + out_dim))
        return tf.Variable(tf.truncated_normal([in_dim, out_dim], stddev=xavier_stddev), dtype=tf.float32)
    
    def neural_net(self, X, weights, biases):
        num_layers = len(weights) + 1
        
        H = 2.0*(X - self.lb)/(self.ub - self.lb) - 1.0
        for l in range(0,num_layers-2):
            W = weights[l]
            b = biases[l]
            H = tf.tanh(tf.add(tf.matmul(H, W), b))
        W = weights[-1]
        b = biases[-1]
        Y = tf.add(tf.matmul(H, W), b)
        return Y
        
    def net_NS(self, x, y, t):
        lambda_1 = self.lambda_1
        lambda_2 = self.lambda_2
        
        psi_and_p = self.neural_net(tf.concat([x,y,t], 1), self.weights, self.biases)
        psi = psi_and_p[:,0:1]
        p = psi_and_p[:,1:2]
        
        u = tf.gradients(psi, y)[0]
        v = -tf.gradients(psi, x)[0]  
        
        u_t = tf.gradients(u, t)[0]
        u_x = tf.gradients(u, x)[0]
        u_y = tf.gradients(u, y)[0]
        u_xx = tf.gradients(u_x, x)[0]
        u_yy = tf.gradients(u_y, y)[0]
        
        v_t = tf.gradients(v, t)[0]
        v_x = tf.gradients(v, x)[0]
        v_y = tf.gradients(v, y)[0]
        v_xx = tf.gradients(v_x, x)[0]
        v_yy = tf.gradients(v_y, y)[0]
        
        p_x = tf.gradients(p, x)[0]
        p_y = tf.gradients(p, y)[0]

        f_u = u_t + lambda_1*(u*u_x + v*u_y) + p_x - lambda_2*(u_xx + u_yy) 
        f_v = v_t + lambda_1*(u*v_x + v*v_y) + p_y - lambda_2*(v_xx + v_yy)
        
        return psi, u, v, p, f_u, f_v
    
    def callback(self, loss, lambda_1, lambda_2):
        print('Loss: %.3e, l1: %.3f, l2: %.5f' % (loss, lambda_1, lambda_2))
      
    def train(self, nIter): 
        
        tf_dict = {self.x_tf: self.x, self.y_tf: self.y, self.t_tf: self.t}

        print("X:", self.sess.run(tf.concat([self.x,self.y,self.t], axis=1)))
        print("Psi:", self.sess.run(self.psi_pred, tf_dict))
        print("U:", self.sess.run(self.u_pred, tf_dict))
        print("V:", self.sess.run(self.v_pred, tf_dict))
        
        w = self.sess.run(self.weights)
        
        return w

# training
model = PhysicsInformedNN(x_train, y_train, t_train, u_train, v_train, layers)
weights = model.train(200000)

# PyTorch
#--------

class PhysicsInformedNNTorch(nn.Module):
    """Implementation of a multilayer perceptron with Dropout in PyTorch.

    Parameters
    ----------
    units : list
        List of integers contaning the number of units from input to output.
    nfunc : str
        Name of the nonlinearity to be used in the network ('relu', 'elu').
    p : float
        Probability of an element to be zeroed by the Dropout's Bernoulli mask.
    param : float
        Extra parameter for activation functions.
    """
    
    def __init__(self, units, nfunc, weights):
        super(PhysicsInformedNNTorch, self).__init__()
        
        self.units = units
        
        # layers: the loop creates fc_0 ... fc_{len(units)-2}, i.e. the hidden
        # layers plus the final linear output layer, with weights copied from
        # the TensorFlow initialization and biases set to zero
        for i in range(1, len(self.units)):
            setattr(self, f"fc_{i-1}", nn.Linear(self.units[i-1], self.units[i], bias=True))
            getattr(self, f"fc_{i-1}").weight = torch.nn.Parameter(
                torch.from_numpy(
                    weights[i-1].T
                ).float()
            )
            nn.init.constant_(getattr(self, f"fc_{i-1}").bias, val=0.0)
        
        if nfunc=='tanh' or nfunc=='sin':
            self.nfunc = getattr(torch, nfunc)
        else:
            self.nfunc = getattr(F, nfunc)
        
    def forward(self, x):
        
        for i in range(len(self.units)-2):
            x = self.nfunc(getattr(self, f"fc_{i}")(x))

        return getattr(self, f"fc_{len(self.units)-2}")(x)

# create the model with the same initialization as the TF model
model_torch = PhysicsInformedNNTorch(layers, 'tanh', weights)

# convert inputs to torch
x_t = torch.from_numpy(x_train).float()
y_t = torch.from_numpy(y_train).float()
t_t = torch.from_numpy(t_train).float()
x_t.requires_grad = True
y_t.requires_grad = True
t_t.requires_grad = True

# preprocess the data
X = torch.cat([x_t, y_t, t_t], dim=1)
lb, _ = X.min(0)
ub, _ = X.max(0)
H = 2.0*(X - lb)/(ub - lb) - 1.0

# model prediction
psi_and_p = model_torch(H)
psi = psi_and_p[:,0:1]
p = psi_and_p[:,1:2]

# create a dummy tensor for computing the gradient
dummy = torch.ones((5,1), requires_grad=False)

# gradients
u = torch.autograd.grad(psi, y_t, grad_outputs=dummy, create_graph=True, only_inputs=True)[0]
v = - torch.autograd.grad(psi, x_t, grad_outputs=dummy, create_graph=True, only_inputs=True)[0]

print("Psi:", psi)
print("U:", u)
print("V:", v)

This code reproduces the issue.

I also tried a “raw” implementation (without nn.Module), as in the original TensorFlow implementation. However, I obtain the same result:

# this code is the continuation of the previous one
# PyTorch raw:
#-------------

def initialize_NN_torch(weights_tf):        
    weights_torch = []
    biases_torch = []
    
    for w_tf in weights_tf:
        w_torch = torch.from_numpy(
            w_tf
        ).float()
        w_torch.requires_grad = True
        b_torch = torch.zeros(w_tf.shape[1])
        b_torch.requires_grad = True

        weights_torch.append(w_torch)
        biases_torch.append(b_torch)  
        
    return weights_torch, biases_torch

def neural_net_torch(H, weights, biases):
    num_layers = len(weights) + 1

    for l in range(0,num_layers-2):
        W = weights[l]
        b = biases[l]
        H = torch.tanh(torch.add(torch.mm(H, W), b))
        
    W = weights[-1]
    b = biases[-1]
    Y = torch.add(torch.mm(H, W), b)
    
    return Y

# data
# convert inputs to torch
x_t2 = torch.from_numpy(x_train).float()
y_t2 = torch.from_numpy(y_train).float()
t_t2 = torch.from_numpy(t_train).float()
x_t2.requires_grad = True
y_t2.requires_grad = True
t_t2.requires_grad = True

# preprocess the data
X2 = torch.cat([x_t2, y_t2, t_t2], dim=1)
lb2, _ = X2.min(0)
ub2, _ = X2.max(0)
H2 = 2.0*(X2 - lb2)/(ub2 - lb2) - 1.0

weights_torch, biases_torch = initialize_NN_torch(weights)

psi_and_p = neural_net_torch(H2, weights_torch, biases_torch)
psi = psi_and_p[:,0:1]
p = psi_and_p[:,1:2]

# create a dummy tensor for computing the gradient
dummy = torch.ones((5,1), requires_grad=False)

# gradients
u = torch.autograd.grad(psi, y_t2, grad_outputs=dummy, create_graph=True, only_inputs=True)[0]
v = - torch.autograd.grad(psi, x_t2, grad_outputs=dummy, create_graph=True, only_inputs=True)[0]

print("Psi:", psi)
print("U:", u)
print("V:", v)

For the sake of completeness, I attach the results from the three approaches:

Tensorflow:

Psi: [[ 0.05774455]
 [ 0.08581608]
 [-0.13483115]
 [-0.07894165]
 [-0.00312972]]
U: [[0.048948  ]
 [0.05040096]
 [0.03029619]
 [0.04972227]
 [0.04007644]]
V: [[-0.07884889]
 [-0.07497925]
 [-0.07125001]
 [-0.07668343]
 [-0.0951325 ]]

PyTorch:

Psi: tensor([[ 0.0577],
        [ 0.0858],
        [-0.1348],
        [-0.0789],
        [-0.0031]], grad_fn=<SliceBackward>)
U: tensor([[ 0.0489],
        [-0.0578],
        [ 0.0303],
        [-0.0616],
        [ 0.0401]], grad_fn=<SliceBackward>)
V: tensor([[-0.0788],
        [ 0.1062],
        [ 0.1445],
        [-0.0767],
        [-0.0951]], grad_fn=<NegBackward>)

PyTorch “raw”:

Psi: tensor([[ 0.0577],
        [ 0.0858],
        [-0.1348],
        [-0.0789],
        [-0.0031]], grad_fn=<SliceBackward>)
U: tensor([[ 0.0489],
        [-0.0578],
        [ 0.0303],
        [-0.0616],
        [ 0.0401]], grad_fn=<SliceBackward>)
V: tensor([[-0.0788],
        [ 0.1062],
        [ 0.1445],
        [-0.0767],
        [-0.0951]], grad_fn=<NegBackward>)

Hi,

I am not 100% sure how TensorFlow works, but it looks like you’re not computing the lower and upper bounds in a differentiable manner in TensorFlow, no?
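
If that’s the case, the extra gradient contributions through lb and ub would also explain why only some entries differ (presumably the rows where the min or max is attained). Detaching the bounds on the PyTorch side, so that they behave like the plain numpy constants in the TF code, should make the two frameworks agree. A minimal, untested sketch of what I mean:

# in the TF code self.lb / self.ub are numpy arrays, so no gradient flows
# through them; detaching the PyTorch bounds should reproduce that behaviour
lb, _ = X.min(0)
ub, _ = X.max(0)
lb, ub = lb.detach(), ub.detach()
H = 2.0*(X - lb)/(ub - lb) - 1.0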

Hi,

Well spotted.
I’m sorry I wasted your time with this trivial error.

Thanks a lot!

I have the exact same problem. Would you mind explaining what the issue is?