Currently I use the following strategy.

Initially I have an a priori known `Propagation_model`, `Measured_th`, and `Slice_coordinates`, where:

- `Propagation_model` — a model object whose `forward` method transforms my guessed input data into a guessed approximation of the measured data. This method may include several FFTs and multiplications with different kernels, but it is known precisely in advance and does not need any optimization, so `Propagation_model` does not require gradients (a hypothetical skeleton is sketched right after this list).
- `Measured_th` — a `[n,1024,1024,2]` complex tensor that represents the n measurements obtained during the experiment; it also requires no optimization.
- `Slice_coordinates` — a list of coordinate tuples `[(up,down,left,right)]` that lets me take the part of `Sample_th` (the tensor representing the reconstructed object) corresponding to a given measurement from `Measured_th`.
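For concreteness, here is a minimal sketch of what such a model object could look like; the class body, the `kernels` attribute, and the `torch.view_as_complex` round-trip are my illustration, not the actual implementation:

```python
import torch

class Propagation_model:
    # Hypothetical skeleton: a fixed forward model with no trainable parameters.
    def __init__(self, kernels):
        self.kernels = kernels  # precomputed complex kernels, known in advance

    def forward(self, probe, obj):
        # interpret the [..., 2] real/imag layout as native complex tensors
        field = torch.view_as_complex(probe.contiguous()) * torch.view_as_complex(obj.contiguous())
        for k in self.kernels:
            field = torch.fft.fft2(field) * k  # one propagation step: FFT + kernel
        return torch.view_as_real(field)
```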
During the optimization procedure I am trying to obtain two complex tensors, `Sample_th` `[4096,4096,2]` and `Probe_th` `[1024,1024,2]`; both of them need to be optimized and require gradients. My loss is defined as the summed squared difference between a certain slice of `Measured_th` and the result of propagating the corresponding slice of `Sample_th` together with `Probe_th`.
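(`Sq_err` below is the helper I call in the training loop; its body here is my reconstruction of "summed squared difference", not necessarily the exact code:)

```python
def Sq_err(pred, target):
    # summed squared difference over all elements of the [1024, 1024, 2] tensors
    return ((pred - target) ** 2).sum()
```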
Before the beginning of the optimization loop I create a set of slices `Slices_th` from `Sample_th`, since only a `[1024,1024,2]` part of `Sample_th` with `(up,down,left,right)` coordinates takes part in producing a certain `Measured_th` slice:

```python
Slices_th = []
for num in range(len(Measured_th)):
    Slices_th.append(Sample_th[Borders.up[num]:Borders.down[num],
                               Borders.left[num]:Borders.right[num], :])
```
All members of `Slices_th` represent slices taken from `Sample_th` and partially overlap each other (a certain region of `Sample_th` belongs to multiple members of `Slices_th` simultaneously).
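These slices are autograd views of `Sample_th`, not copies, so gradients computed through any slice accumulate back into `Sample_th.grad`. A quick sanity check (my own debugging snippet):

```python
s0 = Slices_th[0]
print(s0.shape)    # torch.Size([1024, 1024, 2])
print(s0.grad_fn)  # a SliceBackward node: the slice is still connected to Sample_th
```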
I am trying to optimize `Sample_th` and `Probe_th` with the corresponding optimizer:

```python
optimizer = torch.optim.Adam([
    {'params': Probe_th, 'lr': 0.5e-1, 'weight_decay': 0.004},
    {'params': Sample_th, 'lr': 0.5e-1, 'weight_decay': 0.04},
])
```
In the following loop (`err` and `err_long` collect the per-epoch loss history):

```python
err_long = []
for i in range(epoch_num):
    err = []
    nums = list(range(len(Measured_th)))
    np.random.shuffle(nums)
    for num in nums:
        optimizer.zero_grad()
        Obj_th = Slices_th[num]
        Meas = Measured_th[num]
        loss = Sq_err(Propagation_model.forward(probe=Probe_th, obj=Obj_th), Meas)
        err.append(float(loss.cpu()))
        loss.backward()
        optimizer.step()
    print(np.mean(err), '---', i)
    err_long.append(np.mean(err))
```
Currently, the main problem is that during each `loss.backward()` / `optimizer.step()` the gradient is calculated (and the Adam update applied) for the whole `Sample_th`, which takes most of the time, even though I know that only the small part of it corresponding to the current slice participated in the loss calculation and should be updated during this step. On the other hand, I can't separate `Slices_th` into independent arrays, since I have to take their mutual overlap into account: a change to one of them during `optimizer.step()` should be propagated to the other slices that share the same region.
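A small self-contained illustration of the bottleneck with toy sizes (my own demo, not the real code): the gradient is nonzero only inside the sliced window, yet the dense Adam step still visits every element of the full tensor and its moment buffers:

```python
import torch

Sample = torch.randn(8, 8, 2, requires_grad=True)  # stand-in for Sample_th
opt = torch.optim.Adam([Sample], lr=1e-2)

loss = (Sample[2:4, 2:4, :] ** 2).sum()  # only a 2x2 window participates
loss.backward()

print((Sample.grad != 0).any(dim=-1))  # True only inside the [2:4, 2:4] window
opt.step()  # still a dense update over all 8*8*2 entries
```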
Sorry for such a long explanation, I just don't know how to explain it any shorter)