TF and Pyt gradient are not identical?

hi

here is a strange issue,

import torch
from torch import nn
from torch.autograd import Variable, grad
import tensorflow as tf

#define network weights and input
A_LOC_W = [2.]
A_SCALE_W = [3.]
B_LOC_W = [2.]
B_SCALE_W = [3.]
INPUT_NUM = 1.

#PyTorch A network
class PT_A(nn.Module):
    def __init__(self):
        super(PT_A, self).__init__()
        self.loc_w = Variable(torch.FloatTensor(A_LOC_W), requires_grad=True)
        self.scale_w = Variable(torch.FloatTensor(A_SCALE_W), requires_grad=True)

    def forward(self, z):
        loc = self.loc_w * z
        scale = self.scale_w * z
        z = z * scale + loc
        return z, loc, scale

#Tensorflow A network
class TF_A:
    def __init__(self):
        super(TF_A, self).__init__()
        self.loc_w = tf.Variable(A_LOC_W)
        self.scale_w = tf.Variable(A_SCALE_W)

    def forward(self, z):
        loc = tf.multiply(self.loc_w, z)
        scale = tf.multiply(self.scale_w, z)
        z = tf.multiply(z, scale) + loc
        return z, loc, scale

#PyTorch B network  
class PT_B(nn.Module):
    def __init__(self):
        super(PT_B, self).__init__()
        self.loc_w = Variable(torch.FloatTensor(B_LOC_W), requires_grad=True)
        self.scale_w = Variable(torch.FloatTensor(B_SCALE_W), requires_grad=True)

    def forward(self, x):
        loc = self.loc_w * x
        scale = self.scale_w * x
        return loc, scale

#Tensorflow B network
class TF_B:
    def __init__(self):
        super(TF_B, self).__init__()
        self.loc_w = tf.Variable(B_LOC_W)
        self.scale_w = tf.Variable(B_SCALE_W)

    def forward(self, x):
        loc = tf.multiply(self.loc_w , x)
        scale = tf.multiply(self.scale_w , x)
        return loc, scale


#PyTorch Code        
a_net = PT_A()
b_net = PT_B()

#input -> A
input = Variable(torch.FloatTensor([INPUT_NUM]))
a_pred, a_loc, a_scale = a_net(input)
#output of A -> B
b_loc, b_scale = b_net(a_pred)

#compute KL loss
a_dis = torch.distributions.normal.Normal(loc=a_loc, scale=a_scale)
b_dis = torch.distributions.normal.Normal(loc=b_loc, scale=b_scale)
#kl_loss = torch.distributions.kl._kl_normal_normal(a_dis, b_dis)
kl_loss = torch.log(b_scale/a_scale) + (a_scale**2 + (a_loc - b_loc)**2) / (2 * b_scale**2) - 0.5


#compute grads
grads = grad(kl_loss, [kl_loss, a_loc, a_scale, b_loc, b_scale, a_pred])
kl_loss_loss,a_loc_grad, a_scale_grad, b_loc_grad, b_scale_grad, a_pred_grad = grads[0].item(), \
                                                                 grads[1].item(), \
                                                                 grads[2].item(), \
                                                                 grads[3].item(), \
                                                                 grads[4].item(), \
                                                                 grads[5].item()

#print
print ("PyTorch a_loc value:{}, grad:{}".format(a_loc.item(), a_loc_grad))
print ("PyTorch a_scale value:{}, grad:{}".format(a_scale.item(), a_scale_grad))
print ("PyTorch b_loc value:{}, grad:{}".format(b_loc.item(), b_loc_grad))
print ("PyTorch b_scale value:{}, grad:{}".format(b_scale.item(), b_scale_grad))
print ("PyTorch a_pred value:{}, grad:{} ".format(a_pred.item(), a_pred_grad))
print ("PyTorch kl_loss {}, grad:{}".format(kl_loss.item(),kl_loss_loss))

########Tensorflow Code
_net = TF_A()
b_net = TF_B()

#input -> A
input = tf.placeholder(tf.float32, [1,])
a_pred, a_loc, a_scale = a_net.forward(input)
#output of A -> B
b_loc, b_scale = b_net.forward(a_pred)

#compute KL loss
a_dis = tf.distributions.Normal(loc=a_loc, scale=a_scale)
b_dis = tf.distributions.Normal(loc=b_loc, scale=b_scale)
kl_loss = a_dis.kl_divergence(b_dis)

#compute grads
a_loc_grad = tf.gradients(kl_loss, [a_net.loc_w])[0]
a_scale_grad = tf.gradients(kl_loss, [a_net.scale_w])[0]
b_loc_grad = tf.gradients(kl_loss, [b_net.loc_w])[0]
b_scale_grad = tf.gradients(kl_loss, [b_net.scale_w])[0]
a_pred_grad = tf.gradients(kl_loss, [a_pred])[0]


sess = tf.Session()
sess.run(tf.initialize_all_variables())
vals = sess.run([a_loc, a_loc_grad, \
                 a_scale, a_scale_grad, \
                 b_loc, b_loc_grad, \
                 b_scale, b_scale_grad, \
                 a_pred, a_pred_grad, \
                 kl_loss], {input:[INPUT_NUM]})
#print
print ("Tensorflow a_loc value:{}, grad:{}".format(vals[0][0], vals[1][0]))
print ("Tensorflow a_scale value:{}, grad:{}".format(vals[2][0], vals[3][0]))
print ("Tensorflow b_loc value:{}, grad:{}".format(vals[4][0], vals[5][0]))
print ("Tensorflow b_scale value:{}, grad:{}".format(vals[6][0], vals[7][0]))
print ("Tensorflow a_pred value:{}, grad:{} ".format(vals[8][0], vals[9][0]))
print ("Tensorflow kl_loss value:{} ".format(vals[10][0]))

I used exactly the same setting ( two simple networks, use both the first network and second networks’s output create two norm distributions and use the KL of the two pdfs as loss… the first networks output goes directly in to the second network)

The strange thing is that with the above exact settings, the TF’s grads are different with PTY’s grads. and I did a manual check that the TF’s results are identical with mine

Can you please help me to understand this?

plus: is it possible that this difference root from this peculiar architecture?

PyTorch a_loc value:2.0, grad:0.1706666797399521
PyTorch a_scale value:3.0, grad:-0.11377778649330139
PyTorch b_loc value:10.0, grad:0.035555556416511536
PyTorch b_scale value:15.0, grad:0.04503703862428665
PyTorch a_pred value:5.0, grad:0.20622223615646362 
PyTorch kl_loss 1.2716602087020874, grad:1.0

Tensorflow a_loc value:2.0, grad:0.1706666797399521
Tensorflow a_scale value:3.0, grad:-0.113777756690979
Tensorflow b_loc value:10.0, grad:0.17777778208255768
Tensorflow b_scale value:15.0, grad:0.22518518567085266
Tensorflow a_pred value:5.0, grad:0.20622223615646362 
Tensorflow kl_loss value:1.2716600894927979 
3 Likes

Hi
Many many thanks if you could have a look
@smth and @ptrblck thanks!!

You are calculating gradients for different variables in case of Pytorch than in Tensorflow.

In Pytorch gradient calculation, please change the line to

grads = grad(kl_loss, [kl_loss, a_net.loc_w, a_net.scale_w, b_net.loc_w, b_net.scale_w, a_pred])

1 Like

thanks a million, I think you are right!!