How to use the partial derivative of an op as a new op?

I want to use the partial derivative of a function, but I don’t want to manually calculate it so I use autograd to calculate it for me.

This seems to work (?); see below.

Problem: while calculating the partial derivative of f(x,y) w.r.t. x, autograd also calculates the partial derivatives w.r.t. all parents of both x and y, whereas in the forward pass I only want the derivative w.r.t. x. I can work around this by zeroing out the gradients, but it still seems like wasted work. How can I do this better?

Here f(x,y) = (x*y)^2

import torch
print(torch.__version__)
#> '0.4.0a0+059299b'

from torch.autograd import Variable
from torch import nn
from torch.nn.parameter import Parameter
import torch.optim as optim

def dfdx(x, param, f):
    """Return the elementwise derivative of ``f(x, param)`` w.r.t. ``x``.

    ``x`` is re-wrapped as a fresh leaf that shares ``x.data``, so the result
    does not backpropagate into whatever produced ``x``.  The derivative is
    obtained with ``.backward()``, which means gradients are also accumulated
    into the ``.grad`` of every leaf that ``param`` depends on; the caller has
    to clear those itself.  ``create_graph=True`` keeps the returned gradient
    differentiable so it can be used as part of a larger graph.
    """
    leaf = Variable(x.data, requires_grad=True)
    out = f(leaf, param)
    # Seeding with ones makes every element of `out` contribute its own
    # derivative, i.e. this differentiates sum(out).
    seed = torch.ones_like(out)
    out.backward([seed], create_graph=True)
    return leaf.grad

# Demo: evaluate df/dx for f(x, p) = (p*x)**2 on a small batch, then
# differentiate that result once more w.r.t. the parameter w.
n_batch = 5
# x is a constant input (no gradient requested for it here).
x = Variable(3*torch.ones(n_batch), requires_grad = False)
# w is the single learnable scalar; p repeats it once per batch element,
# so p is a non-leaf that depends on w.
w = Parameter(.5*torch.ones(1), requires_grad = True)
p = Variable(torch.ones(n_batch))*w
# Forward pass 
dfdx_val = dfdx(x,p,lambda x,p : (p*x)**2)
print('df/dx')
print('result :',dfdx_val.data.numpy().T) 
print('sought :',((p*p)*2*x).data.numpy().T) # (p*p)*2*x 
# > df/dx
# > result : [[1.5 1.5 1.5 1.5 1.5]]
# > sought : [[1.5 1.5 1.5 1.5 1.5]]

# dfdx() ran a backward pass, so w.grad has already been populated here.
print('w.grad :',w.grad) # <-- PROBLEM. I didn’t want this to be calculated.
# > w.grad : Variable containing:
# > 135

# Backward pass; 
# Since a .backward() is called during forward in dfdx, 
# need to call it *After* Forward
# Clear what the forward-time backward accumulated, so that w.grad afterwards
# holds only the derivative of mean(df/dx) w.r.t. w.
w.grad.data.zero_() 
dfdx_val.mean().backward() # will only work on scalar! 
print('df/dxdp')
print('result :',(w.grad).data.numpy().T) 
# Analytic value 4*w*x; [:1] keeps a single entry because w.grad has one element.
print('sought :',(4*w*x).data.numpy()[:1].T) # 4*p*x 
# > df/dxdp
# > result : [[6.]]
# > sought : [[6.]]

It might help to think of y in f(x,y) as input from layers above while x is a query.
Keywords: grad of grad, higher-order derivatives, partial derivative as functional/layer…

Here is an example of application btw. (Maybe others can find it useful). It works but I think it’s doing one full redundant backward pass.

def dfdx(x, param, f):
    """Return the elementwise partial derivative of ``f(x, param)`` w.r.t. ``x``.

    Args:
        x: tensor/Variable at which the derivative is evaluated.  It is
            re-wrapped as a fresh leaf sharing ``x.data``, so the result does
            not backpropagate into whatever produced ``x``.
        param: second argument of ``f``; the returned derivative stays
            differentiable w.r.t. ``param`` and everything it depends on.
        f: callable ``f(x, param)`` returning a tensor.

    Returns:
        A tensor shaped like ``x`` holding d sum(f) / d x, or ``None`` when
        ``f`` does not use ``x``.

    The derivative is taken with ``torch.autograd.grad`` restricted to the
    fresh leaf instead of ``.backward()``.  ``.backward()`` would also walk the
    graph behind ``param`` and accumulate into the ``.grad`` of every leaf
    there, which is both redundant work and a side effect the caller would
    have to undo before its real backward pass.
    """
    leaf = Variable(x.data, requires_grad=True)
    out = f(leaf, param)
    # grad_outputs of ones differentiates sum(out), giving one derivative per
    # element; create_graph=True keeps the result usable inside a larger graph
    # (and implies the graph is retained).
    grad_x, = torch.autograd.grad(
        [out],
        [leaf],
        grad_outputs=[torch.ones_like(out)],
        create_graph=True,
        allow_unused=True,
    )
    return grad_x

class Model(torch.nn.Module):
    """Toy network whose output is the derivative of ``functional``.

    The input goes through ``linear1``; ``linear2`` followed by tanh produces
    ``z``.  The output is ``dfdx`` of ``functional`` taken w.r.t. ``z`` with
    the ``linear1`` activation as the second argument.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(1, 1)
        self.linear2 = nn.Linear(1, 1)
        self.functional = lambda x, y: 100 + (x + y) ** 2

    def forward(self, x):
        hidden = self.linear1(x)
        # Should be learned to be 0.
        z = torch.tanh(self.linear2(hidden))
        # If differentiation works, the result is 2 * (z + hidden).  The
        # arguments may also be swapped, dfdx(hidden, z, ...); the original
        # author reports that both work.
        return dfdx(z, hidden, self.functional)

# Train the model so that its output reproduces its input (MSE against x).
model = Model()

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

n = 5
# NOTE: `losses` is never appended to in this snippet.
losses = []
# Fixed evaluation inputs, used only for the final print below.
x = Variable(torch.ones(n,1).cumsum(0)) #1,2,3,4,5
for i in range(1000):
    def closure():
        # Fresh random batch each step; this local x shadows the global one.
        x = Variable(torch.randn(n,1)) 
        y = model(x)
        loss = criterion(y,x)
        # zero_grad() runs AFTER the forward pass, so anything accumulated
        # into .grad while computing y is cleared before loss.backward().
        optimizer.zero_grad()
        loss.backward()
        print(loss.data.numpy())
        return loss    
    optimizer.step(closure)
print(model(x))
1 Like

I usually prefer torch.autograd.grad for anything except differentiating a target function to take an optimizer step.
That also keeps you from accumulating the wrong things…
So instead of dfdx(x, z, self.functional) you do
f = self.functional(x,z) and
dfdx,dfdz = torch.autograd.grad([f],[x,z], retain_graph=True, create_graph=True). (Pass retain_graph only if you need it, i.e. if you get errors about backward being called a second time without it.)

Best regards

Thomas

3 Likes

Some timing results:

| method | time | error untrained | error trained |
|---|---|---|---|
| manual differentiation | 0:03:16 | 353729.12 | 10.401332 |
| autodiff with .grad and .sum() | 0:04:00 | 353527.53 | 16.774494 |
| autodiff with .grad and grad_outputs ones | 0:03:53 | 353527.53 | 16.774494 |
| autodiff with .grad and .sum() w.r.t both inputs | 0:03:51 | 353527.53 | 16.774494 |
| autodiff with .backward() | 0:05:32 | 353729.12 | 3889350.8 |

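Here manual differentiation is to be considered ground truth. Autodiff does add some overhead. It seems like using grad is the winner among the autodiff variants w.r.t. speed, but numerically it (marginally) differs from ground truth. A likely cause: in Model.forward, z is itself computed from x, so differentiating w.r.t. x through the whole graph yields the total derivative (including the path through z) rather than the partial derivative that the manual formula computes.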

When using .backward() as I initially proposed, the training loss is going down and the forward propagation before training is numerically identical to ground truth. I don’t know what happened after that, but I think it’s accumulating gradients somehow, somewhere, in a bad way.

Code

import torch
from datetime import datetime 
# print(torch.__version__)
#> '0.4.0a0+059299b'

from torch.autograd import Variable
from torch import nn
from torch.nn.parameter import Parameter
import torch.optim as optim

def dfdx(x, param, f):
    """Return the elementwise PARTIAL derivative of ``f(x, param)`` w.r.t. ``x``.

    Args:
        x: tensor at which the derivative is evaluated.
        param: second argument of ``f``; it may itself have been computed
            from ``x``.
        f: callable ``f(x, param)`` returning a tensor.

    Returns:
        d sum(f) / d x with ``param`` held fixed, shaped like ``x`` and still
        differentiable w.r.t. both ``x`` and ``param`` (``create_graph=True``).
        ``None`` if ``f`` does not use its first argument.

    Differentiating directly w.r.t. ``x`` would follow every path from the
    output back to ``x``, including the one through ``param`` when ``param``
    depends on ``x``, and so would return the total derivative.  Instead a
    zero-valued leaf is added to ``x`` and the gradient is taken w.r.t. that
    leaf: it only reaches the output through ``f``'s first argument, which is
    exactly the partial derivative.
    """
    probe = torch.zeros_like(x, requires_grad=True)
    f_val = f(x + probe, param)
    # create_graph=True also retains the graph, so no explicit retain_graph.
    dfdx_val, = torch.autograd.grad(
        outputs=[f_val.sum()],
        inputs=[probe],
        create_graph=True,
        allow_unused=True,
    )
    return dfdx_val

def dfdx2(x, param, f):
    """Return the elementwise PARTIAL derivative of ``f(x, param)`` w.r.t. ``x``.

    Variant that passes the unreduced output together with a ``grad_outputs``
    tensor of ones instead of differentiating ``f(...).sum()``.

    Args:
        x: tensor at which the derivative is evaluated.
        param: second argument of ``f``; it may itself have been computed
            from ``x``.
        f: callable ``f(x, param)`` returning a tensor.

    Returns:
        The derivative with ``param`` held fixed, shaped like ``x`` and still
        differentiable w.r.t. both ``x`` and ``param``.  ``None`` if ``f``
        does not use its first argument.

    The gradient is taken w.r.t. a zero-valued leaf added to ``x`` rather than
    w.r.t. ``x`` itself.  When ``param`` depends on ``x``, differentiating
    w.r.t. ``x`` would also follow the path through ``param`` and return the
    total derivative; the added leaf only reaches the output through ``f``'s
    first argument.
    """
    probe = torch.zeros_like(x, requires_grad=True)
    f_val = f(x + probe, param)
    # create_graph=True also retains the graph, so no explicit retain_graph.
    dfdx_val, = torch.autograd.grad(
        outputs=[f_val],
        inputs=[probe],
        grad_outputs=torch.ones_like(f_val),
        create_graph=True,
        allow_unused=True,
    )
    return dfdx_val

def dfdx3(x, param, f):
    """Return the elementwise PARTIAL derivative of ``f(x, param)`` w.r.t. ``x``.

    Variant that asks autograd for the gradients w.r.t. both arguments in one
    call and discards the one for ``param``.

    Args:
        x: tensor at which the derivative is evaluated.
        param: second argument of ``f``; it may itself have been computed
            from ``x``.
        f: callable ``f(x, param)`` returning a tensor.

    Returns:
        d sum(f) / d x with ``param`` held fixed, shaped like ``x`` and still
        differentiable w.r.t. both ``x`` and ``param``.  ``None`` if ``f``
        does not use its first argument.

    The gradient for the first argument is taken w.r.t. a zero-valued leaf
    added to ``x`` rather than w.r.t. ``x`` itself.  When ``param`` depends on
    ``x``, the gradient w.r.t. ``x`` would also include the path through
    ``param`` (listing ``param`` as an input does not stop that), giving the
    total derivative instead of the partial one.
    """
    probe = torch.zeros_like(x, requires_grad=True)
    f_val = f(x + probe, param)
    # create_graph=True also retains the graph, so no explicit retain_graph.
    dfdx_val, _ = torch.autograd.grad(
        outputs=[f_val.sum()],
        inputs=[probe, param],
        create_graph=True,
        allow_unused=True,
    )
    return dfdx_val

def dfdx4(x, param, f):
    """Return the elementwise derivative of ``f(x, param)`` w.r.t. ``x`` via ``.backward()``.

    ``x`` is re-wrapped as a fresh leaf sharing ``x.data``, so the result does
    not backpropagate into whatever produced ``x``.  Because ``.backward()``
    is used, gradients are also accumulated into the ``.grad`` of every leaf
    that ``param`` depends on.  ``create_graph=True`` keeps the returned
    gradient differentiable.
    """
    leaf = Variable(x.data, requires_grad=True)
    out = f(leaf, param)
    # Seeding with ones differentiates sum(out): one derivative per element.
    seed = torch.ones_like(out)
    out.backward([seed], create_graph=True)
    return leaf.grad

class Model(torch.nn.Module):
    """Benchmark network whose output is d functional / d x.

    With ``manual_diff`` true the hand-written derivative ``functionaldfdx``
    is used; otherwise the module-level ``dfdx`` is called (looked up at call
    time, so rebinding that global switches the implementation).
    """

    def __init__(self, manual_diff):
        super().__init__()
        # Layer creation order is kept as-is: it determines both the random
        # draws for the initial weights and the parameter order.
        self.linear1 = nn.Linear(1, 1)
        self.linear2 = nn.Linear(1, 1)
        self.unneccessary_op1 = nn.Linear(1, 100)
        self.unneccessary_op2 = nn.Linear(100, 1)
        self.unneccessary_op3 = nn.Linear(1, 1)
        # The derivative of `functional` w.r.t. x should be
        # 2*(x + 0.01*y**2), which is what `functionaldfdx` spells out.
        self.functional = lambda x, y: 100 + (x + 0.01 * y ** 2) ** 2
        self.functionaldfdx = lambda x, y: 2 * (x + 0.01 * y ** 2)

        self.manual_diff = manual_diff

    def forward(self, x):
        # Filler layers so that the .sum() in dfdx is not the only real work.
        hidden = x
        for layer in (
            self.unneccessary_op1,
            self.unneccessary_op2,
            self.unneccessary_op3,
            self.linear1,
        ):
            hidden = layer(hidden)
        z = self.linear2(hidden)  # Should be learned to be constant 0

        if not self.manual_diff:
            return dfdx(hidden, z, self.functional)
        return self.functionaldfdx(hidden, z)
# Mean-squared-error loss shared by experiment() and evaluate() below.
criterion = nn.MSELoss()

def experiment():
    """Train the global ``model`` with the global ``optimizer`` and print the wall time.

    Runs 100001 optimizer steps; each step draws a fresh random batch of 3000
    scalars and minimises the MSE between the model output and its input.
    """
    began = datetime.now()
    batch_size = 3000
    n_steps = 100001

    def closure():
        batch = Variable(torch.randn(batch_size, 1))
        prediction = model(batch)
        loss = criterion(prediction, batch)
        # Cleared after the forward pass, right before the real backward.
        optimizer.zero_grad()
        loss.backward()
        return loss

    for _ in range(n_steps):
        optimizer.step(closure)

    elapsed = datetime.now() - began
    print('Time elapsed {}'.format(elapsed))

def evaluate():
    """Return the MSE of the global ``model`` on the inputs 1..1000 as a NumPy value.

    The target equals the input, matching the training objective.
    """
    inputs = Variable(torch.ones(1000, 1).cumsum(0))
    error = criterion(model(inputs), inputs)
    return error.data.numpy()

# Benchmark driver: run the same training experiment once per differentiation
# strategy and report the evaluation error before and after training.
#
# Each entry is (label printed for the run, whether Model uses the hand-written
# derivative, implementation to bind to the global name `dfdx`).  Model.forward
# looks `dfdx` up at call time, so rebinding the global selects the variant.
# The list is built before any rebinding, so the first two entries refer to the
# original `dfdx` (the manual run never calls it).
_runs = [
    ('Using manual differentiation', True, dfdx),
    ('Using autodiff with .sum()', False, dfdx),
    ('Using autodiff with grad_outputs ones', False, dfdx2),
    ('Using autodiff with .sum() w.r.t both inputs', False, dfdx3),
    ('Using autodiff with .backward()', False, dfdx4),
]

for _label, _manual_diff, _impl in _runs:
    dfdx = _impl
    # Every run uses the same seed so all variants start from identical
    # weights and see the same random batches; the last run previously used a
    # different seed, which made its numbers not comparable with the others.
    torch.manual_seed(1)
    model = Model(manual_diff = _manual_diff)
    print(_label)
    print('error before: ',evaluate())
    # experiment() and evaluate() read `model` and `optimizer` as globals.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    experiment()
    print('error after: ',evaluate())

Also, big thanks to Yunjey for sharing this gist