Computationally equivalent functions result in different behavior. Am I missing something with respect to autograd?

I have to refactor an old physics-informed neural network (PINN) codebase for further experiments. The old code is in notebook form and contains a lot of explicit, hand-written computational expressions. As you can see in the new implementation, it is much more flexible, which is necessary for our current work. The two snippets below are of the function that I found to be the culprit. I replaced the old code with new code incrementally, testing each change by training the model, comparing metrics, and visualizing the functions modeled by the PINN. All of my other changes have no unintended effect and reproduce the results of the old notebook with differences on the order of 1e-5.

Only when I use the new net_f function defined below does training slow down by a large margin. For instance, with the old notebook's net_f the error reached the 1e-3 range by 1900 iterations, whereas with the new net_f definition below the error is still in the 1e-1 range.

Both implementations use the L-BFGS optimizer with the following configuration:

    # optimizers: using the same settings
    self.optimizer = torch.optim.LBFGS(
        self.dnn.parameters(),
        lr=1.0,
        max_iter=50000,
        max_eval=50000,
        history_size=50,
        tolerance_grad=1e-5,
        tolerance_change=1.0 * np.finfo(float).eps,
        line_search_fn="strong_wolfe",
    )

Note: self.names is just the string "uvpsab", and NCOMP is the number of components (6).

    def cleave_and_name(self, input_tensor):
        # slice a (N, NCOMP) tensor into NCOMP named (N, 1) columns
        return OrderedDict([(name, input_tensor[:, i:j])
                            for name, i, j in zip(self.names,
                                                  range(NCOMP),
                                                  range(1, NCOMP + 1))])

    def name_the_tensors(self, tensors):
        # pair each output tensor with its component name
        return OrderedDict([(name, t) for name, t in zip(self.names, tensors)])
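For clarity, here is a minimal, self-contained sketch of what cleave_and_name produces (assuming NCOMP = 6 and self.names = "uvpsab"):

    import torch
    from collections import OrderedDict

    NCOMP = 6
    names = "uvpsab"

    y = torch.randn(10, NCOMP)  # stand-in for the output of net_u
    cols = OrderedDict([(name, y[:, i:i + 1]) for i, name in enumerate(names)])

    print([(name, col.shape) for name, col in cols.items()])
    # [('u', torch.Size([10, 1])), ('v', torch.Size([10, 1])), ...,
    #  ('b', torch.Size([10, 1]))]

The new net_f built on these helpers is: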

    def net_f(self, x, t):
        """ The pytorch autograd version of calculating residual """
        y = self.net_u(x, t)

        F = self.cleave_and_name(y).values()

        def compute_grad(diff, wrt):
            return torch.autograd.grad(diff, wrt,
                                       grad_outputs=torch.ones_like(diff),
                                       retain_graph=True,
                                       create_graph=True)[0]

        Ft = [compute_grad(f, t) for f in F]
        Fx = [compute_grad(f, x) for f in F]
        Fxx = [compute_grad(f_x, x) for f_x in Fx]

        def compute_output(f, f_t, f_xx, c1, c2, c1_sign):
            return f_t + (c1_sign * c1 * f_xx) + ((-1 * c1_sign) * c2 * f)

        N_F = len(F)
        SQ = sum(f**2 for f in F)
        C1 = [0.5] * N_F
        C2 = [SQ] * N_F
        C1_SIGN = [pow(-1, i) for i in range(1, N_F + 1)]

        F_output = [compute_output(*args)
                    for args in zip(F, Ft, Fxx, C1, C2, C1_SIGN)]

        return self.name_the_tensors(F_output)

The old implementation is as follows.

    def net_f(self, x, t):
        """ The pytorch autograd version of calculating residual """
        y = self.net_u(x, t)

        u = y[:, 0:1]
        v = y[:, 1:2]
        p = y[:, 2:3]
        s = y[:, 3:4]
        a = y[:, 4:5]
        b = y[:, 5:6]

        u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), retain_graph=True, create_graph=True)[0]
        v_t = torch.autograd.grad(v, t, grad_outputs=torch.ones_like(v), retain_graph=True, create_graph=True)[0]

        p_t = torch.autograd.grad(p, t, grad_outputs=torch.ones_like(p), retain_graph=True, create_graph=True)[0]
        s_t = torch.autograd.grad(s, t, grad_outputs=torch.ones_like(s), retain_graph=True, create_graph=True)[0]

        a_t = torch.autograd.grad(a, t, grad_outputs=torch.ones_like(a), retain_graph=True, create_graph=True)[0]
        b_t = torch.autograd.grad(b, t, grad_outputs=torch.ones_like(b), retain_graph=True, create_graph=True)[0]

        u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), retain_graph=True, create_graph=True)[0]
        v_x = torch.autograd.grad(v, x, grad_outputs=torch.ones_like(v), retain_graph=True, create_graph=True)[0]

        p_x = torch.autograd.grad(p, x, grad_outputs=torch.ones_like(p), retain_graph=True, create_graph=True)[0]
        s_x = torch.autograd.grad(s, x, grad_outputs=torch.ones_like(s), retain_graph=True, create_graph=True)[0]

        a_x = torch.autograd.grad(a, x, grad_outputs=torch.ones_like(a), retain_graph=True, create_graph=True)[0]
        b_x = torch.autograd.grad(b, x, grad_outputs=torch.ones_like(b), retain_graph=True, create_graph=True)[0]


        u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x), retain_graph=True, create_graph=True)[0]
        v_xx = torch.autograd.grad(v_x, x, grad_outputs=torch.ones_like(v_x), retain_graph=True, create_graph=True)[0]

        p_xx = torch.autograd.grad(p_x, x, grad_outputs=torch.ones_like(p_x), retain_graph=True, create_graph=True)[0]
        s_xx = torch.autograd.grad(s_x, x, grad_outputs=torch.ones_like(s_x), retain_graph=True, create_graph=True)[0]

        a_xx = torch.autograd.grad(a_x, x, grad_outputs=torch.ones_like(a_x), retain_graph=True, create_graph=True)[0]
        b_xx = torch.autograd.grad(b_x, x, grad_outputs=torch.ones_like(b_x), retain_graph=True, create_graph=True)[0]


        SQ = u**2 + v**2 + p**2 + s**2 + a**2 + b**2

        f_u = v_t - 0.5*u_xx + SQ*u

        f_v = u_t + 0.5*v_xx - SQ*v

        f_p = s_t - 0.5*p_xx + SQ*p

        f_s = p_t + 0.5*s_xx - SQ*s

        f_a = b_t - 0.5*a_xx + SQ*a

        f_b = a_t + 0.5*b_xx - SQ*b


        return f_u, f_v, f_p, f_s, f_a, f_b

Hi Paarulakan!

It’s pretty hard to understand what you’re doing here.

Do all of your "new code changes" include your new "net_f" function?

What is it that doesn’t change by more than 1e-5? One forward pass?
One backward pass? The result of an optimization step?

What is self.dnn?

Is self.dnn the same as self.net_u?

LBFGS.step() takes a closure that, among other things, returns
a scalar loss value. You haven't shown us your closure nor your
call to LBFGS.step(), so we have no idea where things might be
going wrong.

The differences of 1e-5 you mention would seem to be consistent
with single-precision round-off error. But, in any event, you should
try repeating your computation in double precision to make sure
that your new vs. old code isn’t introducing some error not due to
round-off error.
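
For example (a minimal sketch; model stands in for self.dnn, and x, t for your input tensors):

    # minimal sketch: repeat the computation in double precision
    torch.set_default_dtype(torch.float64)  # do this before building the model
    # ... rebuild the model and inputs as usual ...

    # or cast an existing model and existing inputs:
    model.double()
    x = x.detach().to(torch.float64).requires_grad_(True)
    t = t.detach().to(torch.float64).requires_grad_(True)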

Presumably you compute a scalar loss value somewhere. Try
performing a single forward pass with both your old and new net_f.
Do the two computations of your loss value agree up to round-off
error?

It’s not at all clear which parameters you are optimizing and with
respect to which parameters you would be computing gradients.
But, whatever they are, try performing a single backward pass
with both your old and new net_f. Do the resulting gradients
agree up to round-off error?
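
Concretely, something along these lines (a minimal sketch with hypothetical names; compute_loss_old() and compute_loss_new() stand for computing your scalar loss via the old and new net_f):

    # minimal sketch: single forward / backward pass with each version
    model.zero_grad()
    loss_old = compute_loss_old()          # hypothetical: loss via old net_f
    loss_old.backward()
    grads_old = [p.grad.clone() for p in model.parameters()]

    model.zero_grad()
    loss_new = compute_loss_new()          # hypothetical: loss via new net_f
    loss_new.backward()
    grads_new = [p.grad.clone() for p in model.parameters()]

    print('loss diff:', abs(loss_old.item() - loss_new.item()))
    print('max grad diff:',
          max((a - b).abs().max().item()
              for a, b in zip(grads_old, grads_new)))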

Then try performing a single optimization step, and so on.

Good luck.

K. Frank


First of all, thank you so much for going to the trouble of reading this and responding so kindly.

Yes. This is how it went: there is an original notebook that produces a set of plots. I refactored the notebook to be more generic and easily adaptable, carrying over everything from the original notebook. But when I ran my new code, the results were not replicated: training was orders of magnitude slower, i.e. the loss trend (decrease per iteration) did not follow the original notebook's loss trend. So I replaced code blocks in my new code with code blocks from the original notebook one at a time and checked the loss trend each time. Everything in my new code works except net_f. All of my new implementations of the original functions behave computationally the same, except for net_f.

It is the difference in error value of the new code (with the old implementation of net_f) vs. the old code after training for 10k iterations. So net_f is the culprit. I have provided both implementations of net_f in the OP (the top one is new, the bottom one old).

Pretty much. self.net_u() is just a function that concatenates [x, t] and runs self.dnn on the result.

Yes, there is a function named loss_func() which is the closure supplied to optimizer.step(). loss_func runs both net_u and net_f and computes the mean-squared-error loss that drives gradient descent. The new implementation of loss_func is shown below. This part works: I confirmed it by running for 10k iterations, and the results are replicated. If you want the old implementation of loss_func I can share the code, but it is lengthy and basically an inlined variant of the new loss_func, similar to the old net_f implementation.

    def loss_func(self):
        self.iter += 1
        self.optimizer.zero_grad()

        pred_icbc = self.net_u(self.x_u, self.t_u)
        pred_f = self.net_f(self.x_f, self.t_f)

        pred_icbc = self.cleave_and_name(pred_icbc)
        loss_icbc = OrderedDict([
            (name, torch.mean((self.icbc[name] - pred_icbc[name])**2))
            for name in self.names
        ])

        loss_f = OrderedDict([
            (name, torch.mean(pred_f[name]**2)) for name in self.names
        ])

        loss = sum(loss_icbc.values()) + sum(loss_f.values())
        loss.backward()

        if self.iter % self.checkpoint == 0:
            elapsed = time.time() - self.training_start_time
            print(f'iteration: {self.iter} / elapsed: {elapsed:.1f}s / loss: {loss.item():.3e}')
            print('  icbc loss: ' + '/'.join([f'{n}: {l.item():.3e}' for n, l in loss_icbc.items()]))
            print('  f loss: ' + '/'.join([f'{n}: {l.item():.3e}' for n, l in loss_f.items()]))

            torch.save({
                'iter': self.iter,
                'model_state_dict': self.dnn.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': loss,
            }, 'model.pt')

        return loss
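
For completeness, training is then driven by passing this closure to the optimizer:

    # a single call runs the full L-BFGS optimization,
    # since max_iter is set to 50000 in the config above
    self.optimizer.step(self.loss_func)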

The thing that puzzles me is that both net_f functions are semantically the same and so should be computationally the same too, which they evidently are not. I thought my lack of understanding of autograd might be the problem.

Hi Paarulakan!

In general, trying to debug by observing some numerical discrepancy
after 10k iterations is a bad approach. All of those iterations permit
round-off error to accumulate, mix things together, and generally
confuse the issue.

This makes sense. loss_func() (the LBFGS closure) performs
one forward pass followed by a backward pass and returns the
loss value.

Start with this (leaving LBFGS out of it). First call loss_func()
once with your old and new net_f() functions and compare both
the loss value and the computed gradients (that I assume get
populated into the .grad properties of self.dnn’s Parameters).

Do these values agree up to some reasonable round-off error (after
just a single forward / backward pass, i.e, after just a single call to
loss_func())? Do this in both single and double precision just to
make sure that you’re not fooling yourself.

If they do agree, perform a single call to LBFGS.step(loss_func)
and compare the updated Parameter values (presumably those of
self.dnn). Again, try doing this in both single and double precision
to make sure that a “real” discrepancy isn’t hiding itself behind what
might look like round-off error.
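
That comparison could look something like this (a minimal sketch; model_old / model_new and opt_old / opt_new are hypothetical, identically-initialized copies of your model and optimizer, driven by the old and new net_f respectively):

    # minimal sketch: one LBFGS step with each version, then compare
    opt_old.step(loss_func_old)
    opt_new.step(loss_func_new)

    for (name, p_old), (_, p_new) in zip(model_old.named_parameters(),
                                         model_new.named_parameters()):
        print(name, (p_old - p_new).abs().max().item())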

The basic idea is to perform divide-and-conquer debugging where
you try to locate where the error first occurs (or at least where the
discrepancy can first be detected). Also, don’t focus on “derived”
quantities like how your training is converging, but instead on more
basic quantities like the loss value, the gradients, or the updated
parameter values.

If you don’t see a discrepancy after a single LBFGS.step(), try
running (smallish) numbers of iterations until you do.

Conversely, if you do see a discrepancy (but can’t track down its
cause), try running plain-vanilla SGD as your optimizer where you
can run a single step at a time (LBFGS takes multiple steps under
the hood), and then increase the number of SGD iterations until
you can detect a discrepancy.
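
A sketch of that (the learning rate and step count are placeholders):

    # minimal sketch: plain SGD performs exactly one update per .step()
    self.optimizer = torch.optim.SGD(self.dnn.parameters(), lr=1e-3)

    for _ in range(n_steps):     # grow n_steps until a discrepancy appears
        self.loss_func()         # the closure zeroes grads and calls backward()
        self.optimizer.step()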

Assuming that you’re right about net_f() being the culprit, after
you’ve located the first detectable discrepancy, you can now print
out intermediate values in both the old and new net_f()s to nail
down which specific operation differs between the two versions.
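
For instance, run both versions on the same x and t and compare corresponding intermediates (a minimal sketch; u_t, u_x, u_xx come from the old net_f and Ft, Fx, Fxx from the new one):

    # minimal sketch: compare corresponding intermediates pairwise
    for label, old, new in [('u_t', u_t, Ft[0]),
                            ('u_x', u_x, Fx[0]),
                            ('u_xx', u_xx, Fxx[0])]:
        print(label, torch.allclose(old, new),
              (old - new).abs().max().item())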

Good luck!

K. Frank
