Custom loss function leading to loss of gradients

I need to implement a custom loss function for a regression problem. My data are one-dimensional, each sample having a length of 1000. In total I have 100,000 samples of these 1000-feature vectors. Each sample (a vector of length 1000) is connected to 4 parameters (a, b, m, s). The connection between samples and parameters is known for the training data, as usual.

I need to develop a deep learning model that outputs 4 parameters (a2, b2, m2, s2) given a vector of length 1000. The custom loss function that I need to define in PyTorch is then

    loss = ∫ function(y, a, b, m, s, a2, b2, m2, s2) dy

The problem I have is that I cannot compute this loss inside the torch framework, because PyTorch has no numerical-integration routine such as a hypothetical torch.quad. I therefore defined a custom loss function using SciPy's quad, but quad returns a plain float. I can convert this float to a tensor, but because I build a brand-new tensor at every iteration of the loop, it is disconnected from the computation graph: the gradients never accumulate into the model, they are effectively re-initialized at each iteration.
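
To illustrate the failure mode with a made-up one-parameter integrand (not my actual loss), this is essentially what happens:

    import torch
    from scipy.integrate import quad

    p = torch.tensor(2.0, requires_grad=True)        # stand-in for a model output
    val = quad(lambda y: y * p.item(), 0.0, 1.0)[0]  # plain float: the graph is cut here
    loss = torch.tensor(val, requires_grad=True)     # a fresh, unconnected leaf
    loss.backward()
    print(p.grad)                                    # None: nothing flows back to p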

Here is the relevant portion of my code (modified from a 1D convolutional net):

def train(ep):
    model.train()
    total_loss = 0
    count = 0
    train_idx_list = np.arange(len(X_train), dtype="int32")
    np.random.shuffle(train_idx_list)

    for idx in train_idx_list:
        data_line = X_train[idx]
        x = Variable(data_line)
        if args.cuda:
            x = x.cuda()

        optimizer.zero_grad()
        output = model(x.unsqueeze(0)).squeeze(0)

        # rescale the known (target) parameters for this sample
        a = (parameter1[idx] * 40.0) - 20.0
        b = parameter2[idx] * 10
        m = parameter3[idx] * 4.8063e-07
        s = parameter4[idx] * 3.8284e-07

        # rescale the predicted parameters; .item() yields python floats
        a2 = (output[0][0].item() * 40.0) - 20.0
        b2 = output[0][1].item() * 10
        m2 = output[0][2].item() * 4.8063e-07
        s2 = output[0][3].item() * 3.8284e-07

        loss = torch.tensor(
            quad(integrand, 0, 5e-7,
                 args=(a[0], b[0], m[0], s[0], a2, b2, m2, s2),
                 points=(max(0, min(m[0] - 2 * s[0], m2 - 2 * s2)),
                         max(m[0] + 2 * s[0], m2 + 2 * s2)))[0],
            requires_grad=True, device=cuda0)
        total_loss += loss.item()
        count += 1

        # backpropagate first, then clip the freshly computed gradients
        loss.backward()
        if args.clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        if idx > 0 and idx % args.log_interval == 0:
            cur_loss = total_loss / count
            print("Epoch {:2d} | lr {:.5f} | loss {:.5f}".format(ep, lr, cur_loss))
            total_loss = 0.0
            count = 0

The integrand is defined as follows:

    import math

    def integrand(y, a, b, m, s, a2, b2, m2, s2):
        # absolute difference between two skewed generalized-normal
        # densities, one with the target parameters and one with the
        # predicted parameters
        f1 = (0.5 * b / (s * math.gamma(1.0 / b))
              * math.exp(-(abs(y - m) / s) ** b)
              * math.erfc(-a * (y - m) / (math.sqrt(2.0) * s)))
        f2 = (0.5 * b2 / (s2 * math.gamma(1.0 / b2))
              * math.exp(-(abs(y - m2) / s2) ** b2)
              * math.erfc(-a2 * (y - m2) / (math.sqrt(2.0) * s2)))
        return abs(f1 - f2)

How can one deal with such a problem?

Hi Seyhmus!

Pytorch’s autograd cannot compute gradients for computations that
are performed outside of the pytorch framework and you will not be
able to backpropagate through them – unless you give them some
help.

You will need to package your quad() computation as a Function
and provide it with a backward() function that computes its gradient.

See Extending torch.autograd for a start.

I haven’t looked carefully at integrand() and the underlying math,
but it would seem that one possibility could be the following:

You can switch the order of differentiation and integration (provided
things converge appropriately), so you could analytically differentiate
integrand(), code up that formula, and then integrate it numerically
(using your quad() or something similar) and use that as the main
ingredient in the backward() function of your custom autograd
Function.
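
As a minimal sketch of that pattern, with a made-up one-parameter
integrand (not your actual loss):

    import torch
    from scipy.integrate import quad

    class IntegralLoss(torch.autograd.Function):
        # loss (p) = integral from 0 to 1 of (y - p)**2 dy
        @staticmethod
        def forward(ctx, p):
            ctx.save_for_backward(p)
            val = quad(lambda y: (y - p.item()) ** 2, 0.0, 1.0)[0]
            return p.new_tensor(val)

        # d loss / dp = integral from 0 to 1 of -2 * (y - p) dy, i.e. the
        # analytically differentiated integrand, integrated numerically
        @staticmethod
        def backward(ctx, grad_output):
            p, = ctx.saved_tensors
            dval = quad(lambda y: -2.0 * (y - p.item()), 0.0, 1.0)[0]
            return grad_output * p.new_tensor(dval)

    p = torch.tensor(0.3, requires_grad=True)
    loss = IntegralLoss.apply(p)
    loss.backward()
    print(p.grad)   # close to 2 * p - 1 = -0.4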

Best.

K. Frank

Hello Frank,
Thank you very much for your response. I was able to write a function as follows:

from TCN.poly_music.myintegrate2 import integrand, integrand_alpha, integrand_beta, integrand_mu, integrand_sigma
from scipy.integrate import quad 
import torch

class my_loss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, m, s, a2, b2, m2, s2):
        pts = (max(0, min(m - 2 * s, m2 - 2 * s2)),
               max(m + 2 * s, m2 + 2 * s2))
        result = quad(integrand, 0, 5e-7,
                      args=(a, b, m, s, a2, b2, m2, s2), points=pts)[0]
        ctx.save_for_backward(a, b, m, s, a2, b2, m2, s2)
        return torch.FloatTensor([result])

    @staticmethod
    def backward(ctx, grad_output):
        a, b, m, s, a2, b2, m2, s2 = ctx.saved_tensors
        pts = (max(0, min(m - 2 * s, m2 - 2 * s2)),
               max(m + 2 * s, m2 + 2 * s2))
        args = (a, b, m, s, a2, b2, m2, s2)
        grad_alpha = quad(integrand_alpha, 0, 5e-7, args=args, points=pts)[0]
        grad_beta = quad(integrand_beta, 0, 5e-7, args=args, points=pts)[0]
        grad_mu = quad(integrand_mu, 0, 5e-7, args=args, points=pts)[0]
        grad_sig = quad(integrand_sigma, 0, 5e-7, args=args, points=pts)[0]

        # one gradient per forward() input: None for the four target
        # parameters, the integrated partials for a2, b2, m2, s2
        return (None, None, None, None,
                torch.FloatTensor([grad_alpha]), torch.FloatTensor([grad_beta]),
                torch.FloatTensor([grad_mu]), torch.FloatTensor([grad_sig]))

In this function, integrand, integrand_alpha, integrand_beta, integrand_mu and integrand_sigma are functions of this sort:

    def integrand(y, a, b, m, s, a2, b2, m2, s2):
        return abs(((1/2)*b*math.exp(1)**((-1....)

And the first values that I get are as follows:

     grad_alpha = 0.210814
     grad_beta = -0.000110
     grad_mu =   563428.64
     grad_sig =   3272315.65

In the main program I call it like this:

    loss = my_loss.apply(a[0],b[0],m[0],s[0],a2,b2,m2,s2)

Here the loss looks like this:

    tensor([1.1734])

The rest of the code is the same as in my original question. I got the following error:

    element 0 of tensors does not require grad and does not have a grad_fn

Later I searched the internet and saw that this error means the tensor carries no gradient information, and that a quick fix would be to wrap the loss in a Variable right after it is defined:

        loss = my_loss.apply(a[0],b[0],m[0],s[0],a2,b2,m2,s2)
        loss = Variable(loss, requires_grad = True)

This lets the code run, but the results show that the loss does not decrease:

 Epoch  1 | lr 0.00100 | loss 1.62780
 Epoch  1 | lr 0.00100 | loss 1.63068
 Epoch  1 | lr 0.00100 | loss 1.61624
 Epoch  1 | lr 0.00100 | loss 1.62332
 Epoch  1 | lr 0.00100 | loss 1.62789
 Epoch  1 | lr 0.00100 | loss 1.64436

I think something is missing in the Function that I defined; otherwise I should have seen the loss decrease. To be on the safe side, I also checked the code for a single parameter and without the custom loss function, i.e. with an MSE loss defined through torch. This worked very well and gave the expected results. I calculated the gradients with Mathematica and tested them in Python, and they matched the Matlab results, so the gradients should be OK.

What, then, is missing in my Function or in its application in the main code?

Hi Seyhmus!

These three pieces of information (the "does not require grad"
error, the fact that wrapping loss in a Variable makes the error
go away, and the non-decreasing loss) strongly suggest that none
of the arguments you are passing into my_loss.apply() have
requires_grad = True. Try printing out the arguments right
before calling my_loss.apply() to see whether they show up
with requires_grad = True.

Looking at your code – and making some assumptions to fill in the
gaps – a, b, etc., come from parameter1, parameter2, etc., so I
have no reason to think that they have requires_grad = True.
Furthermore, even though a2, b2, etc., come from output, which,
as the result of model, presumably has requires_grad = True,
you then call .item(), which returns a python scalar (rather than a
tensor), and, as such, strips away any requires_grad property.

If you are trying to optimize the parameters of the model that
produces a2, b2, etc., then you have to move these calls to .item()
inside of my_loss.forward(). Pass a2, etc., into my_loss.apply() as
pytorch tensors (and do check that they have requires_grad = True
to make sure that something else isn't getting broken) and convert
them to python scalars (in order to pass them to quad()) with calls
to .item() inside of forward().
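
Something along these lines (a sketch; output here is just a
stand-in for the result of your model):

    import torch

    output = torch.randn(1, 4, requires_grad=True)   # stand-in for model(x)

    # scale the predictions as tensors, with no .item(), so the
    # computation graph stays intact
    a2 = output[0][0] * 40.0 - 20.0
    b2 = output[0][1] * 10.0
    m2 = output[0][2] * 4.8063e-07
    s2 = output[0][3] * 3.8284e-07
    print(a2.requires_grad)   # should print True

    # loss = my_loss.apply(a[0], b[0], m[0], s[0], a2, b2, m2, s2)
    # and inside my_loss.forward(), and only there, convert for quad():
    #     scalars = tuple(t.item() for t in (a, b, m, s, a2, b2, m2, s2))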

As a general rule, if you find yourself tempted to “manually” turn on
requires_grad = True somewhere in the middle of a calculation,
you should resist that temptation. Such a “quick fix” is almost never
an actual fix. Typically, you have “broken the computation graph”
somewhere upstream of your purported fix, and you need to find and
fix the actual, upstream problem. (As noted above, in your case, the
calls to .item() are likely the culprit that has broken the computation
graph.)

As an aside, Variable has been deprecated for some time now, and
the autograd functionality in Variable has since been subsumed into
pytorch tensors, so just use regular tensors.

Your version of backward() has a flaw. For reasons I won’t go into,
it should work in the particular context in which you call it, but it’s
wrong because its output does not depend on its grad_output
argument.

Let’s assume for simplicity that you write a Function whose forward()
method takes a single tensor as an argument and returns a single
tensor. That is, it takes a collection of scalar arguments and returns a
collection of scalar results. The derivative of such a function is the
so-called Jacobian, that is, the matrix of partial derivatives of all of
the results with respect to all of the arguments.

Your backward() function – which we think of as calculating the
derivative of the corresponding forward() – doesn’t return the full
Jacobian (although it might choose to compute the full Jacobian
internally). Rather it returns the matrix-vector product of the Jacobian
matrix with the grad_output vector that pytorch’s autograd framework
passes into your backward() function.

For your my_loss Function to work correctly in (perfectly plausible)
more general contexts, its backward() method can’t simply ignore
its grad_output argument, but instead has to use it to compute (and
return) the expected Jacobian-grad_output product.

(The Jacobian-grad_output product returned by your backward()
will then be passed in as the grad_output argument of the backward()
of the Function next upstream in the computation graph. That’s how
autograd implements the chain rule.)
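
For concreteness, here is a sketch of a backward() that respects
grad_output, reusing the names from your post (and assuming that
integrand_alpha, etc., are the partial derivatives with respect to
a2, b2, m2, s2):

    @staticmethod
    def backward(ctx, grad_output):
        a, b, m, s, a2, b2, m2, s2 = ctx.saved_tensors
        pts = (max(0, min(m - 2 * s, m2 - 2 * s2)),
               max(m + 2 * s, m2 + 2 * s2))
        args = (a, b, m, s, a2, b2, m2, s2)
        partials = [quad(f, 0, 5e-7, args=args, points=pts)[0]
                    for f in (integrand_alpha, integrand_beta,
                              integrand_mu, integrand_sigma)]
        # Jacobian-vector product: scale each partial derivative by
        # grad_output so the chain rule composes correctly
        return (None, None, None, None,
                *(grad_output * g for g in partials))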

Best.

K. Frank