Optimize a module with parameters per datapoint

Hi,

I have an nn.Module that contains a single tensor of size n x k, where n is my dataset length. Each row holds parameters that I need to tune for one specific data point.
I want to optimize over the parameters of this module (along with another module whose size does not scale with the data).
I use the Adam optimizer, and in each epoch there is a single batch in which the gradient of the i-th row is non-zero.
However, on subsequent batches the optimizer keeps modifying the i-th row even though its gradient is 0 (since Adam, as well as SGD with momentum, uses some 'memory' of previous gradients).
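
For reference, here is a minimal, self-contained sketch of the kind of setup I mean (the class and variable names are just placeholders):

```python
import torch
import torch.nn as nn

class PerSampleParams(nn.Module):  # placeholder name
    def __init__(self, n, k):
        super().__init__()
        # one row of tunable parameters per data point
        self.theta = nn.Parameter(torch.randn(n, k))

    def forward(self, idx):
        # idx: indices of the data points in the current batch
        return self.theta[idx]

model = PerSampleParams(n=1000, k=16)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Only rows 0..31 get a non-zero gradient in this batch, but once a row
# has been touched in an earlier step, Adam's running averages keep
# moving it on later steps even when its gradient is exactly zero.
idx = torch.arange(32)
loss = model(idx).pow(2).sum()
opt.zero_grad()
loss.backward()
opt.step()
```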

What would be the best approach to prevent this? What I would like is for the optimizer to apply a 'step' to the i-th row only if the i-th data point is present in the batch. Is there an elegant solution to a problem like this?

Thanks!

Adam keeps running statistics of past gradients, so it has a certain inertia (let's call it that).
To avoid it, you can create n separate nn.Parameters instead of putting them all in a single matrix; see the sketch below.
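
A rough sketch of that idea (class and variable names are mine): if each row is its own nn.Parameter and you call zero_grad(set_to_none=True), the rows that do not take part in a batch keep grad=None, and the optimizer skips parameters that have no gradient.

```python
import torch
import torch.nn as nn

class PerSampleParams(nn.Module):  # placeholder name
    def __init__(self, n, k):
        super().__init__()
        # one separate parameter per data point instead of a single n x k matrix
        self.rows = nn.ParameterList([nn.Parameter(torch.randn(k)) for _ in range(n)])

    def forward(self, idx):
        # only the rows of the current batch enter the autograd graph
        return torch.stack([self.rows[i] for i in idx.tolist()])

model = PerSampleParams(n=1000, k=16)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

idx = torch.arange(32)
loss = model(idx).pow(2).sum()
opt.zero_grad(set_to_none=True)  # unused rows keep grad=None
loss.backward()
opt.step()  # parameters with grad=None are skipped, so their Adam state never changes
```

Keep in mind that with a very large n this can be slow, since the optimizer then has to loop over n individual parameters.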
Alternatively, you can use plain SGD with no momentum.
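
With plain SGD (no momentum, no weight decay) the update is just lr * grad, so rows whose gradient is zero in a given batch are left exactly as they were. A minimal sketch:

```python
import torch
import torch.nn as nn

theta = nn.Parameter(torch.randn(1000, 16))  # single n x k matrix as before
opt = torch.optim.SGD([theta], lr=1e-2, momentum=0.0, weight_decay=0.0)

# The update is theta <- theta - lr * theta.grad, so a row with a zero
# gradient receives a zero update and does not move.
```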