So I am training a network for image reconstruction as part of an unrolled optimization scheme: first I apply a neural network to an image to get a denoised image, then I apply a data-consistency step, which runs a conjugate gradient algorithm with respect to the denoised image. The forward function of the method in pseudocode is:
x = init_guess
for i in range(num_iterations):
    denoised_x = neural_network(x)
    data_consistent_x = conj_grad(data, denoised_x)
    x = data_consistent_x
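For concreteness, here is the unrolled loop above as a runnable sketch (the denoiser and the conj_grad stand-in here are hypothetical placeholders, not my actual modules; the real data-consistency step is the CG function further down):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real denoising network
denoiser = nn.Conv2d(1, 1, kernel_size=3, padding=1)

def conj_grad(data, denoised):
    # Placeholder data-consistency step; the real one runs conjugate gradient
    return 0.5 * (data + denoised)

def unrolled_forward(init_guess, data, num_iterations=3):
    x = init_guess
    for _ in range(num_iterations):
        denoised_x = denoiser(x)       # denoising step
        x = conj_grad(data, denoised_x)  # data-consistency step
    return x

x0 = torch.zeros(1, 1, 8, 8)
data = torch.randn(1, 1, 8, 8)
out = unrolled_forward(x0, data)
```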
However, during training the inclusion of the conjugate gradient step adds hugely to the GPU memory used (without it: 2 GB; with it: 45 GB).
This conjugate gradient algorithm (Conjugate gradient method - Wikipedia) is an iterative method involving matrix operations, so there is another for loop inside the forward pass.
Below is the function I am using, where init_image is x and masked_kspace, sens_mps, and mask are the data.
def conjg_grad_correction(self, init_image, masked_kspace, sens_mps, mask):
    # Right-hand side of the normal equations; also the initial residual
    r = self.my_rhs(masked_kspace, sens_mps, mask, init_image)
    p = r
    r_complex = torch.complex(r[..., 0], r[..., 1])
    p_complex = torch.complex(p[..., 0], p[..., 1])
    rtr = (torch.conj(r_complex) * r_complex).sum()
    x_complex = torch.zeros_like(r_complex)
    for i in range(10):
        # Apply the normal operator to the current search direction
        Ap = self.myAtAp(p, sens_mps, mask)
        Ap_complex = torch.complex(Ap[..., 0], Ap[..., 1])
        alpha = rtr / (torch.conj(p_complex) * Ap_complex).sum()
        x_complex = x_complex + alpha * p_complex
        r_complex = r_complex - alpha * Ap_complex
        rtr_new = (torch.conj(r_complex) * r_complex).sum()
        beta = rtr_new / rtr
        p_complex = r_complex + beta * p_complex
        p = torch.view_as_real(p_complex)
        rtr = rtr_new
    return torch.view_as_real(x_complex)
I am rather new to PyTorch, so I am wondering whether it is possible to reduce the memory used during training by modifying this conjugate gradient implementation.
Thanks in advance!