Is it possible to optimize only a selected part of a tensor in each iteration?

Hello!
Is there a way to optimize only a certain part of a tensor?
I am trying to use PyTorch as a backend for a ptychographic engine, which requires optimizing a large [2^12,2^12,2] tensor O. However, I can formulate the loss function only for a set of overlapping slices of this big tensor, O[up:down,left:right,:], so I need to calculate the updates for each of these slices consistently. Currently, I just iterate through the list of these slices and run the optimizer with respect to the whole big tensor. This requires calculating the gradient for the whole of O at every step, which takes most of the time, even though I know that only a small part of it, O[up:down,left:right,:], influences the loss function and should be updated in that step.

So is there any way to turn on gradient calculation for only part of O at each step, or is there some other way to solve this problem?
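To illustrate the cost being described: when the loss depends only on a slice of a leaf tensor, `backward()` still allocates a gradient with the full shape of that tensor, filled with explicit zeros outside the slice. A minimal sketch with a small stand-in tensor (the sizes and coordinates are arbitrary):

```python
import torch

# Small stand-in for the [4096, 4096, 2] object tensor O
O = torch.randn(64, 64, 2, requires_grad=True)

up, down, left, right = 8, 24, 16, 32
loss = (O[up:down, left:right, :] ** 2).sum()
loss.backward()

# The gradient has the full shape of O, not the shape of the slice
print(O.grad.shape)  # torch.Size([64, 64, 2])

# Everything outside the slice is stored as an explicit zero
mask = torch.zeros_like(O, dtype=torch.bool)
mask[up:down, left:right, :] = True
print(bool(O.grad[~mask].eq(0).all()))  # True
```

At full size this means a dense 4096x4096x2 gradient on every iteration, and the optimizer then processes every element of it.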

How are you calculating the gradients at the moment?
Could you post some dummy code, as I’m not sure how the loss is being calculated at the moment?

Currently I utilize the following strategy:
Initially I have the a priori known Propagation_model, Measured_th and Slice_coordinates, where:
Propagation_model - a model object whose forward method transforms my guessed input data into a guessed approximation of the measured data. This method may include several FFTs and multiplications with different kernels, but it is known precisely in advance and does not require any optimization, so Propagation_model does not require gradients.
Measured_th - an [n,1024,1024,2] complex tensor which represents the n measurements obtained during the experiment; it also requires no optimization.
Slice_coordinates - a list of coordinate tuples [(up,down,left,right)] which lets me take the part of Sample_th (the tensor representing the reconstructed object) that corresponds to a given measurement in Measured_th.
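Propagation_model itself isn't shown. For anyone who wants something concrete to experiment with, here is a hypothetical stand-in (not the actual model from the question) that follows the [..., 2] real/imaginary layout used for all tensors here: multiply probe and object patch as complex fields, then take a 2D FFT. The name `propagate` and the choice of `norm="ortho"` are assumptions.

```python
import torch

def propagate(probe: torch.Tensor, obj: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for Propagation_model.forward.

    probe, obj: real tensors of shape [N, N, 2] holding (real, imag) parts.
    Returns the far-field wave in the same [N, N, 2] layout.
    """
    probe_c = torch.view_as_complex(probe)  # [N, N] complex view, no copy
    obj_c = torch.view_as_complex(obj)
    exit_wave = probe_c * obj_c             # pixel-wise complex product
    # Orthonormal 2D FFT over the last two dims (energy-preserving)
    far_field = torch.fft.fft2(exit_wave, norm="ortho")
    return torch.view_as_real(far_field)


probe = torch.randn(16, 16, 2)
obj = torch.randn(16, 16, 2)
print(propagate(probe, obj).shape)  # torch.Size([16, 16, 2])
```

`torch.view_as_complex` also accepts a slice such as `O[up:down, left:right, :]` of a contiguous [H, W, 2] tensor, so patches can be passed in without copying.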

During the optimization procedure I am trying to obtain two complex tensors, Sample_th [4096,4096,2] and Probe_th [1024,1024,2]; both of them need to be optimized and require gradients.

My loss is defined as the summed squared difference between a given slice of Measured_th and the result of propagating the corresponding slice of Sample_th together with Probe_th.
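Sq_err is not defined in the post. A sketch that matches the description (summed squared difference, with complex values stored as a trailing dimension of size 2) could look like this; the lowercase name `sq_err` is hypothetical:

```python
import torch

def sq_err(pred: torch.Tensor, meas: torch.Tensor) -> torch.Tensor:
    """Summed squared difference between two [N, N, 2] (real, imag) tensors.

    Squaring and summing over the trailing dimension of size 2 gives
    |pred - meas|^2 per pixel, so the total is the squared L2 norm of
    the complex difference.
    """
    return ((pred - meas) ** 2).sum()
```

For example, if the two tensors differ only at one pixel by 3 + 4j, the loss is 3^2 + 4^2 = 25.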

Before the beginning of the optimization loop, I create a list of slices Slices_th from Sample_th, since only a [1024,1024,2] part of Sample_th with coordinates (up,down,left,right) contributes to a given Measured_th slice:

Slices_th = []
for num in range(len(Measured_th)):
    Slices_th.append(Sample_th[Borders.up[num]:Borders.down[num], Borders.left[num]:Borders.right[num], :])

All members of Slices_th are slices taken from Sample_th, and they partially overlap each other (a given part of Sample_th belongs to several members of Slices_th simultaneously).
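Because Slices_th is built by basic slicing, each entry is a view into Sample_th rather than a copy. That is what keeps the overlap consistent in the current code: gradients from overlapping slices accumulate in the shared pixels of Sample_th, and every in-place update of Sample_th (such as the one optimizer.step() performs) is visible through all the views. A small check with hypothetical coordinates:

```python
import torch

Sample_th = torch.zeros(8, 8, 2, requires_grad=True)
# Two overlapping windows (hypothetical (up, down, left, right) coordinates)
coords = [(0, 4, 0, 4), (2, 6, 2, 6)]
Slices_th = [Sample_th[u:d, l:r, :] for (u, d, l, r) in coords]

# Gradients from overlapping windows add up in the shared pixels
loss = Slices_th[0].sum() + Slices_th[1].sum()
loss.backward()
print(Sample_th.grad[3, 3, 0].item(), Sample_th.grad[0, 0, 0].item())  # 2.0 1.0

# An in-place update of Sample_th is seen through every view containing that pixel
with torch.no_grad():
    Sample_th[3, 3, :] += 1.0  # pixel (3, 3) lies in both windows

print(Slices_th[0][3, 3, 0].item(), Slices_th[1][1, 1, 0].item())  # 1.0 1.0
```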

I am trying to optimize Sample_th and Probe_th with the following optimizer:

optimizer = torch.optim.Adam([
    {'params': Probe_th, 'lr': 0.5e-1, 'weight_decay': 0.004},
    {'params': Sample_th, 'lr': 0.5e-1, 'weight_decay': 0.04},
])
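One detail that matters here: with `weight_decay > 0`, `torch.optim.Adam` adds `weight_decay * param` to the gradient of every element of the parameter, so `optimizer.step()` changes all of Sample_th even where the loss gradient is zero. (Even without weight decay, Adam's running moments keep moving elements that had a nonzero gradient in earlier steps.) A small check on a stand-in tensor:

```python
import torch

torch.manual_seed(0)
p = torch.randn(8, 8, requires_grad=True)
opt = torch.optim.Adam([p], lr=0.05, weight_decay=0.04)

before = p.detach().clone()
loss = (p[:2, :2] ** 2).sum()  # the loss only touches a 2x2 corner
opt.zero_grad()
loss.backward()
opt.step()

# Elements outside the corner have zero loss gradient, yet they still move
outside = torch.ones(8, 8, dtype=torch.bool)
outside[:2, :2] = False
changed = (p.detach() != before)[outside]
print(changed.all().item())  # True
```

So restricting only the gradient computation would not be enough; the update itself also has to be restricted to the current window.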

In the following loop:

import numpy as np

err_long = []
for i in range(epoch_num):
    err = []  # losses of the current epoch
    nums = list(range(len(Measured_th)))
    np.random.shuffle(nums)
    for num in nums:
        optimizer.zero_grad()
        Obj_th = Slices_th[num]
        Meas = Measured_th[num]
        loss = Sq_err(Propagation_model.forward(probe=Probe_th, obj=Obj_th), Meas)
        err.append(loss.item())
        loss.backward()
        optimizer.step()

    print(np.mean(err), '---', i)
    err_long.append(np.mean(err))
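One common workaround, sketched here with small stand-in sizes: keep Sample_th as a plain tensor without requires_grad, copy the current window out as a fresh leaf tensor on each iteration, and write the update back into Sample_th in place. Autograd then only allocates an [N, N, 2] gradient, and the update touches only that window; because the next iteration slices the already-updated Sample_th, overlapping windows stay consistent. Note that this sketch swaps Adam on the object for plain gradient descent with a fixed step size `obj_lr`. The `propagate` function, the coordinates, the sizes and the random `Measured_th` are all hypothetical placeholders.

```python
import torch

torch.manual_seed(0)

H, N = 32, 8                       # stand-ins for 4096 and 1024
Sample_th = torch.randn(H, H, 2)   # plain tensor: the big array does not require grad
Probe_th = torch.randn(N, N, 2, requires_grad=True)
coords = [(0, 8, 0, 8), (4, 12, 4, 12), (8, 16, 0, 8)]  # hypothetical (up, down, left, right)
Measured_th = torch.randn(len(coords), N, N, 2)          # placeholder data

probe_opt = torch.optim.Adam([Probe_th], lr=0.05)
obj_lr = 0.01                      # gradient-descent step for the object (hypothetical value)
Sample_init = Sample_th.clone()    # kept only to check which pixels changed


def propagate(probe, obj):
    # Hypothetical propagation: complex product followed by an orthonormal 2D FFT
    out = torch.fft.fft2(torch.view_as_complex(probe) * torch.view_as_complex(obj), norm="ortho")
    return torch.view_as_real(out)


for epoch in range(3):
    for num in torch.randperm(len(coords)).tolist():
        up, down, left, right = coords[num]
        # Copy the current window out of Sample_th as a fresh leaf tensor,
        # so autograd only allocates a gradient of shape [N, N, 2]
        patch = Sample_th[up:down, left:right, :].clone().requires_grad_(True)

        probe_opt.zero_grad()
        loss = ((propagate(Probe_th, patch) - Measured_th[num]) ** 2).sum()
        loss.backward()            # gradients for Probe_th and patch only
        probe_opt.step()

        # Gradient-descent step on the window, written back in place; overlapping
        # windows see these updated values the next time they are sliced
        with torch.no_grad():
            Sample_th[up:down, left:right, :] -= obj_lr * patch.grad

changed = (Sample_th != Sample_init).any(dim=-1)
# No window covers rows >= 16 or columns >= 12, so those pixels never change
print(changed[16:, :].any().item(), changed[:, 12:].any().item())  # False False
```

The clone costs one [N, N, 2] copy per iteration, which is small compared with a dense gradient and optimizer update over the full object.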

Currently, the main problem is that on every iteration the gradient is computed for the whole of Sample_th (in loss.backward()), and optimizer.step() then updates the whole tensor. This takes most of the time, even though I know that only the small part corresponding to the current slice took part in the loss calculation and should be optimized in this step. On the other hand, I can't split Slices_th into independent arrays, since I have to take their mutual overlap into account: a change to one of them during optimizer.step() must be propagated to all the slices that overlap it.
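If Adam-like behaviour is wanted for the object as well, one option is to keep Adam's moment buffers at the full size of Sample_th but read and update only the current window on each step. The class below is a hypothetical helper, not a PyTorch API; it is similar in spirit to `torch.optim.SparseAdam`, which likewise updates only the moments that appear in the gradient. Because every window indexes the same Sample_th, `m` and `v`, an update made through one window is immediately seen by every overlapping window. It assumes the window gradient comes from a cloned patch leaf, as in the usage example.

```python
import torch


class PatchAdam:
    """Hypothetical helper: Adam-style update applied to one window of a large tensor.

    Moment buffers have the full shape of the parameter, so overlapping windows
    share state, but each step reads and writes only the current window.
    L2 weight decay is applied to the window only (torch.optim.Adam would decay
    every element on every step).
    """

    def __init__(self, param, lr=0.05, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
        self.param = param
        self.lr, self.eps, self.wd = lr, eps, weight_decay
        self.b1, self.b2 = betas
        self.m = torch.zeros_like(param)
        self.v = torch.zeros_like(param)
        self.t = 0  # one global step count used for bias correction

    @torch.no_grad()
    def step(self, window, grad):
        up, down, left, right = window
        idx = (slice(up, down), slice(left, right))
        # Basic indexing returns views, so the in-place ops below update the full buffers
        p, m, v = self.param[idx], self.m[idx], self.v[idx]
        g = grad + self.wd * p
        self.t += 1
        m.mul_(self.b1).add_(g, alpha=1 - self.b1)
        v.mul_(self.b2).addcmul_(g, g, value=1 - self.b2)
        m_hat = m / (1 - self.b1 ** self.t)
        v_hat = v / (1 - self.b2 ** self.t)
        p.sub_(self.lr * m_hat / (v_hat.sqrt() + self.eps))


# Usage on a small stand-in object
torch.manual_seed(0)
Sample_th = torch.randn(32, 32, 2)
opt = PatchAdam(Sample_th, lr=0.05, weight_decay=0.04)
before = Sample_th.clone()

window = (4, 12, 8, 16)
patch = Sample_th[4:12, 8:16, :].clone().requires_grad_(True)
loss = (patch ** 2).sum()  # placeholder loss on the window
loss.backward()
opt.step(window, patch.grad)

changed = (Sample_th != before).any(dim=-1)
print(changed[4:12, 8:16].all().item(), changed.sum().item())  # True 64
```

Using a single global step count means rarely visited windows get the same bias correction as frequently visited ones; a per-pixel step counter would be an alternative design.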

Sorry for such a long explanation, I just don't know how to explain it more briefly :)