Error when running LBFGS to solve a non-linear inverse problem

This is my first time posting in these forums, so please let me know if I can describe my issue with more clarity. I am only running on CPU right now, but will move to more powerful GPUs once I get it working on CPU. I am using PyTorch 1.6.0.

My intention is to use LBFGS in PyTorch to iteratively solve my non-linear inverse problem. I have a class for solving this problem iteratively, and it uses the LBFGS optimizer with the following parameters (mostly the defaults):

  self.optimizer = torch.optim.LBFGS(self.model.parameters(), lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=10, line_search_fn='strong_wolfe')

The class has a method called “run” that performs the optimization, given below:

def run(self, data):

    data = torch.tensor(data, dtype=torch.double, requires_grad=False)

    def closure():
        if torch.is_grad_enabled():
            self.optimizer.zero_grad()
        data_pred = self.model()
        loss = self.loss_fn(data_pred, data)
        if loss.requires_grad:
            loss.backward()
        return loss

    loss = self.loss_fn(self.model(), data).detach().numpy()
    print('Init loss: ', loss)
    self.optimizer.step(closure)
    loss = self.loss_fn(
        self.model(), data).detach().numpy()
    print('Final loss: ', loss)

    # loss=closure().numpy()
    rec = self.model.rec.detach().numpy()
    return rec, loss

Note that “self.model” has trainable parameters that are updated during the LBFGS optimization. However, I get the following error from autograd. Strangely, I don’t get the same error when using Adam. I know I am not including all the files referenced in the error trace, but let me know if I can describe my issue in more detail.

File "/Users/mohan3/Desktop/Devs/Phase_Img/Code/phasetorch/phasetorch/nlpret.py", line 111, in nlopt_phaseret
    estor.run(meas_np)
File "/Users/mohan3/Desktop/Devs/Phase_Img/Code/phasetorch/phasetorch/pret.py", line 48, in run
    self.optimizer.step(closure)
File "/Users/mohan3/anaconda3/envs/phasetorch/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
File "/Users/mohan3/anaconda3/envs/phasetorch/lib/python3.8/site-packages/torch/optim/lbfgs.py", line 316, in step
    flat_grad = self._gather_flat_grad()
File "/Users/mohan3/anaconda3/envs/phasetorch/lib/python3.8/site-packages/torch/optim/lbfgs.py", line 255, in _gather_flat_grad
    view = p.grad.view(-1)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Thank you very much for your time and help!

Could you post the model definition, please?
It seems that an internal view() call is failing, since a grad tensor is apparently not contiguous, which seems to be a real bug.
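In the meantime, a quick check like the following (reusing the self.model attribute from your snippet, purely for diagnosis) can show which parameter or gradient is non-contiguous after a backward pass:

# Run after loss.backward() so the .grad fields are populated.
for name, p in self.model.named_parameters():
    grad_contig = None if p.grad is None else p.grad.is_contiguous()
    print(name, 'param contiguous:', p.is_contiguous(), 'grad contiguous:', grad_contig)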

Note that I get this error with Python 3.8.6 and PyTorch 1.6.0 on macOS. I tried debugging this problem further, and you can reproduce the error by running the standalone code below. The error seems to be related to the numpy-to-tensor conversion. Towards the end of the code, I show an approach that gets rid of the error by defining the numpy arrays differently. So, it seems like the contiguous or non-contiguous nature of the arrays is inherited from numpy.

Also, how do I properly format the code below? Some parts are rendered as code, while others appear as plain text.

import torch
import numpy as np


class ModelTest(torch.nn.Module):
    def __init__(self, N_y, N_x, d):
        super(ModelTest, self).__init__()
        tran = np.random.randn(d, 1, N_y, N_x, 2).astype(
            np.double, order='C')
        self.transform = torch.tensor(
            tran, dtype=torch.double, requires_grad=False)

    def forward(self, x):
        x = torch.fft(x, signal_ndim=2)
        y_real = x[:, :, :, :, 0]*self.transform[:, :, :, :, 0] - \
            x[:, :, :, :, 1]*self.transform[:, :, :, :, 1]
        y_imag = x[:, :, :, :, 0]*self.transform[:, :, :, :, 1] + \
            x[:, :, :, :, 1]*self.transform[:, :, :, :, 0]
        x = torch.stack((y_real, y_imag), dim=-1)
        x = torch.ifft(x, signal_ndim=2)
        x = x[:, :, :, :, 0]*x[:, :, :, :, 0]+x[:, :, :, :, 1]*x[:, :, :, :, 1]
        return x


class ModelTestOpt(torch.nn.Module):
    def __init__(self, init, N_y, N_x, d):
        super(ModelTestOpt, self).__init__()
        self.model = ModelTest(N_y, N_x, d)
        self.rec = torch.nn.Parameter(torch.tensor(
            init, dtype=torch.double, requires_grad=True))

    def forward(self):
        return self.model(self.rec)


class Sim_Algo:
    def __init__(self, N_y, N_x, d):
        self.model = ModelTest(N_y, N_x, d)

    def run(self, data):
        data = torch.tensor(data, dtype=torch.double, requires_grad=False)
        new_shape = (-1, data.size(-4), data.size(-3),
                     data.size(-2), data.size(-1))
        data = torch.reshape(data, new_shape)
        return self.model(data)


class Iter_Algo:
    def __init__(self, init, N_y, N_x, d):
        init = init.reshape(
            (-1, init.shape[-4], init.shape[-3], init.shape[-2], init.shape[-1]))
        self.model = ModelTestOpt(init, N_y, N_x, d)
        self.loss_fn = torch.nn.MSELoss(reduce=True, reduction='mean')
        self.optimizer = torch.optim.LBFGS(self.model.parameters(), lr=1, max_iter=20, max_eval=None,
                                           tolerance_grad=1e-07, tolerance_change=1e-09, history_size=10, line_search_fn='strong_wolfe')

    def run(self, data):
        data = torch.tensor(data, dtype=torch.double, requires_grad=False)

        def closure():
            if torch.is_grad_enabled():
                self.optimizer.zero_grad()
            data_pred = self.model()
            loss = self.loss_fn(data_pred, data)
            if loss.requires_grad:
                loss.backward()
            return loss

        loss = self.loss_fn(self.model(), data).detach().numpy()
        print('Init loss: ', loss)
        self.optimizer.step(closure)

        loss = self.loss_fn(
            self.model(), data).detach().numpy()
        print('Final loss: ', loss)

        rec = self.model.rec.detach().numpy()
        return rec, loss


N_y, N_x = 72, 96
d = 3

init = np.random.randn(1, N_y, N_x, 2).astype(dtype=np.double, order='C')
transform = np.random.randn(d, N_y, N_x, 2).astype(dtype=np.double, order='C')

simtor = Sim_Algo(N_y, N_x, d)
data = simtor.run(init)

# Comment out the next three statements (they make init non-contiguous) to prevent the error
init = np.random.randn(1, N_y, N_x)+1j*np.random.randn(1, N_y, N_x)
init = np.stack((np.real(init), np.imag(init)),
                axis=0).astype(dtype=np.double, order='C')
init = np.moveaxis(init, 0, -1)
# Uncomment the line below instead to prevent the error
#init = np.random.randn(1, N_y, N_x, 2).astype(dtype=np.double, order='C')
solver = Iter_Algo(init, N_y, N_x, d)
solver.run(data)

Thanks for the code.
self.rec = torch.nn.Parameter(torch.from_numpy(init)) is creating a tensor, which is not contiguous due to the moveaxis operation on the numpy array.
Use:

self.rec = torch.nn.Parameter(torch.from_numpy(init).contiguous())

and the code should work.
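As a small illustration of what is going on (shapes here are arbitrary, just for the demo):

import numpy as np
import torch

init = np.random.randn(2, 1, 5, 6)
init = np.moveaxis(init, 0, -1)            # permutes the strides
print(init.flags['C_CONTIGUOUS'])          # False

rec = torch.from_numpy(init)               # from_numpy keeps the numpy strides
print(rec.is_contiguous())                 # False
# rec.view(-1) would raise the same RuntimeError that _gather_flat_grad hits

rec = torch.from_numpy(init).contiguous()  # copies into a contiguous layout
print(rec.is_contiguous())                 # True
print(rec.view(-1).shape)                  # flattening with view() now works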

You can add code snippets by wrapping them in three backticks (```). I’ve formatted your code for you.

Yes, your modification does fix the error. Thanks a lot!

However, the error message is a bit cryptic since it arises from pytorch back-propagation internals. If back-propagation requires contiguous tensors, it may be useful to have an assertion-like error check somewhere to prevent users from accidentally using non-contiguous trainable tensors.

A better error message would surely be useful.
However, I don’t know if forcing contiguous tensors would break other use cases and as you’ve described before, some optimizers seem to work fine with non-contiguous tensors.

True! As far as the error check goes, since it appears that only LBFGS requires contiguous tensors, a good place to check for contiguous trainable tensors would be within the LBFGS class’s init function, since it receives all the trainable tensors (i.e. the parameters) as input anyway.
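For instance, a thin wrapper along these lines (just a sketch, assuming the parameters are passed as a plain iterable of tensors rather than as parameter groups) would reject non-contiguous parameters up front:

import torch

class CheckedLBFGS(torch.optim.LBFGS):
    def __init__(self, params, **kwargs):
        # Sketch of the suggested check: fail early instead of inside step().
        params = list(params)
        for p in params:
            if not p.is_contiguous():
                raise ValueError(
                    'LBFGS flattens parameters and their gradients with view(-1); '
                    'please make the parameters contiguous (e.g. tensor.contiguous()) '
                    'before passing them to the optimizer.')
        super().__init__(params, **kwargs)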

@albanD do you think a check should be added to LBFGS, or would this have any side effects?

Making the Tensor contiguous won’t work for sure, as it would break in-place ops.

Adding a check for .is_contiguous() and raising an error if it is not might be overly strict, as you can sometimes do the view even when the tensor is not contiguous.
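For example, view() can still succeed on a non-contiguous tensor when it only splits a dimension, while the full flatten used by LBFGS fails:

import torch

x = torch.arange(24).reshape(2, 3, 4)
y = x.transpose(0, 1)                # shape (3, 2, 4), non-contiguous
print(y.is_contiguous())             # False
print(y.view(3, 2, 2, 2).shape)      # works: only the last dim is split
# y.view(-1) raises: view size is not compatible with input tensor's size and stride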

it arises from pytorch back-propagation internals.

It does not actually happen in the backprop internals but in the LBFGS optimizer.
That being said, we might be able to catch this error and improve the error message?
Also @ptrblck note that the error is from the .grad field here, no?

Yeah, you are right.
A non-contiguous input tensor would also create non-contiguous gradients, or is this not always true?

I think _gather_flat_grad might get an additional check or a better error message? What do you think?

We try to follow the input layout in most cases, but this is indeed not guaranteed.
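A quick empirical check along these lines shows what happens in a given version (whether the gradient follows the parameter's non-contiguous layout is, as said above, not guaranteed):

import numpy as np
import torch

arr = np.moveaxis(np.random.randn(2, 5, 6), 0, -1)   # non-contiguous numpy array
p = torch.nn.Parameter(torch.from_numpy(arr))         # non-contiguous parameter
(p * 2.0).sum().backward()
print(p.is_contiguous(), p.grad.is_contiguous())      # layout of .grad may or may not match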

I think _gather_flat_grad might get an additional check or a better error message? What do you think?

We can definitely improve that error message to at least recommend that the user make their weights contiguous as a first step (but I don’t think we want to enforce it, as it would break valid user code in cases where the view is valid on a non-contiguous Tensor).
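For reference, the flattening happens in _gather_flat_grad via p.grad.view(-1), as shown in the traceback above. A sketch of the kind of message improvement discussed here could look roughly like this (simplified, not the actual implementation):

import torch

def gather_flat_grad_with_hint(params):
    # Same flattening that LBFGS._gather_flat_grad performs, but with a more
    # actionable message when view(-1) fails on a non-contiguous gradient.
    views = []
    for p in params:
        if p.grad is None:
            view = p.new(p.numel()).zero_()
        elif p.grad.is_sparse:
            view = p.grad.to_dense().view(-1)
        else:
            try:
                view = p.grad.view(-1)
            except RuntimeError as e:
                raise RuntimeError(
                    'LBFGS needs to flatten every gradient with view(-1), which failed '
                    'for a non-contiguous gradient. As a first step, try making the '
                    'corresponding parameter contiguous (tensor.contiguous()) before '
                    'passing it to the optimizer.') from e
        views.append(view)
    return torch.cat(views, 0)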
