Memory issues with iterative algorithm

So I am training a network for image reconstruction as part of an unrolled optimization scheme: first I apply a neural network to an image to get a denoised image, then I apply a data-consistency step where I run a conjugate-gradient algorithm with respect to the denoised image. The forward function of the method, in pseudocode, is:

x = init_guess

for i in range(num_iterations):
    denoised_x = neural_network(x)
    data_consistent_x = conj_grad(data, denoised_x)
    x = data_consistent_x

However, during training the inclusion of the conjugate-gradient step increases the GPU memory used enormously (without it: 2 GB, with it: 45 GB).

This conjugate-gradient algorithm (Conjugate gradient method - Wikipedia) is an iterative method involving matrix operations, so there is another for loop inside.

Below is the function I am using, where init_image is x, and masked_kspace, sens_mps, and mask are the data.

def conjg_grad_correction(self, init_image, masked_kspace, sens_mps, mask):
    # Right-hand side b of the linear system, stored as (..., 2) real/imag pairs.
    r = self.my_rhs(masked_kspace, sens_mps, mask, init_image)
    p = r

    r_complex = torch.complex(r[..., 0], r[..., 1])
    p_complex = torch.complex(p[..., 0], p[..., 1])

    # rtr = r^H r, the squared residual norm.
    rtr = (torch.conj(r_complex) * r_complex).sum()
    x_complex = torch.zeros_like(r_complex)

    for i in range(10):
        # Apply the (implicit) system matrix A to the search direction p.
        Ap = self.myAtAp(p, sens_mps, mask)
        Ap_complex = torch.complex(Ap[..., 0], Ap[..., 1])

        # Standard CG updates: step size, solution estimate, residual.
        alpha = rtr / (torch.conj(p_complex) * Ap_complex).sum()
        x_complex = x_complex + alpha * p_complex
        r_complex = r_complex - alpha * Ap_complex

        rTrnew = (torch.conj(r_complex) * r_complex).sum()
        beta = rTrnew / rtr
        # New search direction, conjugate to the previous ones.
        p_complex = r_complex + beta * p_complex
        p = torch.view_as_real(p_complex)

        rtr = rTrnew

    return torch.view_as_real(x_complex)

I am rather new to PyTorch, so I am wondering whether it is possible to reduce the memory used during training by modifying this conjugate-gradient implementation?

Thanks in advance!

Hi Mithrandir!

I don’t understand your use case or follow what you are doing in your
code, so I can’t give you a concrete answer. But I do have some
comments, below.

First, it is perfectly fine to use a loop or iterative algorithm in your
forward pass. As long as the calculations inside of your loop use
(differentiable) pytorch tensor operations (that support autograd),
autograd and backpropagation will work properly.
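As a toy illustration (nothing to do with your operators), a loop built from tensor operations backpropagates just like any other sequence of layers:

import torch

x = torch.ones(3, requires_grad=True)
y = x
for _ in range(5):
    y = 2.0 * y              # each pass through the loop adds a node to the graph
y.sum().backward()
print(x.grad)                # tensor([32., 32., 32.]), i.e. d(sum(2^5 * x)) / dx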

As you’ve seen, however, each iteration through the loop will add
another “layer” to the “computation graph,” consuming memory, and
taking time when you backpropagate, perhaps to the extent that you
run out of memory or the computation becomes impractically slow.

Depending on your use case, you may be able to turn off autograd for
your iterative forward function, and write your own corresponding
backward function that requires less memory (and/or time).
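The standard mechanism for this is a custom torch.autograd.Function. Here is a toy sketch (a power-series exp, nothing to do with your code) showing the pattern: an iterative loop in forward() that builds no graph, plus a hand-written backward():

import torch

class ExpBySeries(torch.autograd.Function):
    # Toy example: compute exp(x) with an iterative loop in forward and
    # supply the known derivative in backward, so the loop itself never
    # appears in the computation graph.

    @staticmethod
    def forward(ctx, x):
        # forward() of a Function runs with autograd disabled, so the
        # loop's intermediates are not stored for backpropagation.
        result = torch.ones_like(x)
        term = torch.ones_like(x)
        for n in range(1, 20):
            term = term * x / n
            result = result + term
        ctx.save_for_backward(result)   # all that backward needs
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        # d exp(x) / dx = exp(x), so the vector-Jacobian product is simple.
        return grad_output * result

x = torch.randn(5, requires_grad=True)
y = ExpBySeries.apply(x)
y.sum().backward()
print(torch.allclose(x.grad, torch.exp(x)))   # True (up to series truncation)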

The “baseline” use of the conjugate-gradient algorithm is as a linear
solver where the linear operation A . x (with A a matrix and x a
vector) is only given implicitly, where A is sparse, so that
computing A . x is cheaper than a full matrix-vector product, or where
it is impractical to store the full A^-1 in memory.
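For concreteness, here is a minimal CG solver of my own (a sketch for a real-valued right-hand side; the complex case needs conjugated inner products, as in your code) that only ever touches A through a matvec callable:

import torch

def cg_solve(matvec, b, num_iters=10):
    # Solve A x = b for symmetric positive-definite A, where A is only
    # available through matvec(p) = A @ p (plain, unpreconditioned CG).
    x = torch.zeros_like(b)
    r = b.clone()                     # residual b - A x for x = 0
    p = r.clone()                     # initial search direction
    rtr = torch.dot(r, r)
    for _ in range(num_iters):
        Ap = matvec(p)
        alpha = rtr / torch.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rtr_new = torch.dot(r, r)
        p = r + (rtr_new / rtr) * p
        rtr = rtr_new
    return x

A = torch.randn(8, 8, dtype=torch.float64)
A = A @ A.T + 8.0 * torch.eye(8, dtype=torch.float64)   # symmetric positive definite
b = torch.randn(8, dtype=torch.float64)
x = cg_solve(lambda p: A @ p, b, num_iters=20)
print(torch.allclose(A @ x, b))                          # True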

Let’s say that this is your use case, so that you are solving A . x = b
for x, and you require gradients of the computed x with respect to the
elements of the vector b.

Then x = A^-1 . b, so that the Jacobian (of x with respect to b) is
A^-1. In backpropagation you typically compute the dot product of
such a Jacobian with a vector of numerical gradients that are being
backpropagated from later stages of your network. Let’s call this
vector g and the result of the dot product r = A^-1 . g.

r can now be obtained by solving A . r = g for r, for which you can
again employ the conjugate-gradient algorithm, taking advantage of
any sparsity in A (or if you need to because the linear operation
A . x is only given implicitly).
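Putting the two pieces together, a sketch (assuming A is a fixed, symmetric positive-definite operator, so we only need gradients with respect to b, and reusing the cg_solve helper above) could look like:

import torch

class CGSolve(torch.autograd.Function):
    # Sketch: x = A^-1 . b via CG in forward (a Function's forward builds
    # no graph), and grad_b = A^-1 . g via another CG solve in backward.
    # A (given only through matvec) is treated as a constant here.

    @staticmethod
    def forward(ctx, b, matvec, num_iters):
        ctx.matvec = matvec
        ctx.num_iters = num_iters
        return cg_solve(matvec, b, num_iters)

    @staticmethod
    def backward(ctx, grad_output):
        # The Jacobian of x with respect to b is A^-1 (symmetric), so the
        # vector-Jacobian product r = A^-1 . g is just another CG solve.
        grad_b = cg_solve(ctx.matvec, grad_output, ctx.num_iters)
        return grad_b, None, None

A = torch.randn(8, 8, dtype=torch.float64)
A = A @ A.T + 8.0 * torch.eye(8, dtype=torch.float64)   # symmetric positive definite
b = torch.randn(8, dtype=torch.float64, requires_grad=True)
x = CGSolve.apply(b, lambda p: A @ p, 20)
x.sum().backward()                            # here g = ones, so b.grad = A^-1 . ones
expected = torch.linalg.solve(A, torch.ones(8, dtype=torch.float64))
print(torch.allclose(b.grad, expected))       # True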

You can also write your own backward function, potentially using the
conjugate-gradient algorithm, in the case that you also require the gradients
of the computed x with respect to the elements of the matrix A. Doing
so is relatively straightforward, but it may or may not be efficient
or practical, depending on the details of the sparsity structure of A.
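For reference, the standard result (for x = A^-1 . b, a loss L, and r = A^-1 . g as above, with A symmetric as required by CG) is that the gradient with respect to A is the outer product -r . x^T. A small dense toy check of my own, using autograd on torch.linalg.solve as ground truth:

import torch

A = torch.randn(5, 5, dtype=torch.float64)
A = A @ A.T + 5.0 * torch.eye(5, dtype=torch.float64)   # symmetric positive definite
A.requires_grad_(True)
b = torch.randn(5, dtype=torch.float64)

x = torch.linalg.solve(A, b)          # x = A^-1 . b
x.sum().backward()                    # so dL/dx = g = ones

g = torch.ones(5, dtype=torch.float64)
r = torch.linalg.solve(A.detach(), g) # A symmetric, so A^-T . g = A^-1 . g
print(torch.allclose(A.grad, -torch.outer(r, x.detach())))   # True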

Best.

K. Frank