Batch Optimization of "Input Samples"

Hello!

I have made an implementation that optimizes input vector samples for a given loss function. With a single or a few input vector samples it works perfectly with Adam. However, when I place 10K samples in a batch and take the mean of their individual losses as the optimization target, the optimization gets very slow and hence inefficient. (The time per iteration does not change much, but the input vectors get optimized in much smaller steps. I have tried increasing the learning rate, but that did not help much.) My guess is that autograd is simply unable to properly figure out how much each input sample contributes to the increase or reduction in the average loss. Has anybody here worked on such a problem before? What would be the best way to tackle it? I would be more than glad if anybody can help. Thank you in advance for your time and consideration.
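One thing that can be checked directly: if each sample's loss depends only on that sample, the gradient of the mean loss with respect to a sample is the gradient of that sample's own loss divided by the batch size N. The sketch below verifies this numerically in plain Python. The squared-distance loss, the toy data and all function names are assumptions made for illustration, not the implementation described above.

```python
def sample_loss(x, target):
    """Hypothetical per-sample loss: squared distance to a target vector."""
    return sum((xi - ti) ** 2 for xi, ti in zip(x, target))


def mean_loss(batch, target):
    """Mean of the per-sample losses, i.e. the batch optimization target."""
    return sum(sample_loss(x, target) for x in batch) / len(batch)


def numeric_grad(f, batch, i, j, h=1e-3):
    """Central-difference derivative of f(batch) w.r.t. batch[i][j]."""
    plus = [row[:] for row in batch]
    minus = [row[:] for row in batch]
    plus[i][j] += h
    minus[i][j] -= h
    return (f(plus) - f(minus)) / (2 * h)


target = [1.0, -2.0]
batch = [[0.001 * k, -0.002 * k] for k in range(1000)]  # N = 1000 toy samples
n = len(batch)

# Derivative of the mean loss w.r.t. the first coordinate of sample 3.
g_mean = numeric_grad(lambda b: mean_loss(b, target), batch, 3, 0)
# Derivative of that sample's own loss w.r.t. the same coordinate.
g_own = numeric_grad(lambda b: sample_loss(b[3], target), batch, 3, 0)

ratio = g_own / g_mean  # approximately equal to n
print(ratio)
```

For plain gradient descent, summing the losses instead of averaging them (or multiplying the learning rate by N) undoes this factor. Adam divides each step by a running estimate of the gradient magnitude, so there the 1/N factor mostly cancels and matters mainly when the scaled gradients become comparable to Adam's epsilon term.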

Note: I would also be fine if the learning goes fast only for a minority of samples that are the best candidates, by making the optimizer focus mostly on them. In that case, would it make sense, for example, to optimize a percentile of the loss rather than the average loss over all samples? (E.g. optimizing the worse candidates could result in better exploration, but the exploitation would be much worse in that case.)
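A minimal sketch of one way to read the percentile idea: average only the k lowest per-sample losses and ignore the rest. The function name `best_k_mean` and the toy loss values are hypothetical; in an autograd framework the same selection would be done with a sort or top-k operation on the loss vector.

```python
def best_k_mean(losses, k):
    """Return the mean of the k lowest losses and the indices of those samples."""
    order = sorted(range(len(losses)), key=lambda i: losses[i])
    chosen = order[:k]  # the current best candidates
    value = sum(losses[i] for i in chosen) / k
    return value, chosen


# Toy per-sample losses for a batch of 10 samples.
losses = [5.0, 0.2, 3.1, 0.9, 7.4, 1.5, 0.4, 2.2, 9.0, 6.3]
value, chosen = best_k_mean(losses, k=3)
print(value, chosen)
```

With such a target only the selected samples receive a gradient in a given step; all other samples stay where they are until they enter the selected set.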

Note: I am not working in that space, but I would assume that generating adversarial attacks against a deep learning model is one example use case where such inputs need to be optimized / generated.

Sincerely,
Kamer

I found an initial solution to the problem, which is rescaling the gradients of the input samples so that the smallest per-sample gradient norm is one. However, to keep using a regular optimizer, I need to figure out how that affects the momentum of the optimizer, etc. At the moment, I am just updating the input samples by adding their rescaled gradients to them. I am looking forward to any suggestions from more experienced people.
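The rescaling described above could be sketched as follows, assuming it means dividing every per-sample gradient by the smallest per-sample gradient norm in the batch. The function names, the `step` parameter and the toy gradients are hypothetical.

```python
import math


def rescale_by_smallest_norm(grads, eps=1e-12):
    """Divide all per-sample gradients by the smallest per-sample gradient norm."""
    norms = [math.sqrt(sum(g * g for g in row)) for row in grads]
    smallest = max(min(norms), eps)  # guard against division by zero
    return [[g / smallest for g in row] for row in grads]


def apply_update(samples, scaled_grads, step=1.0):
    """Add the rescaled gradients to the samples (use a negative step to subtract)."""
    return [
        [x + step * g for x, g in zip(x_row, g_row)]
        for x_row, g_row in zip(samples, scaled_grads)
    ]


# Toy batch of three 2-D samples with gradient norms 5, 0.5 and 2.
samples = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
grads = [[3.0, 4.0], [0.3, 0.4], [0.0, 2.0]]

scaled = rescale_by_smallest_norm(grads)
updated = apply_update(samples, scaled)
print(scaled)
```

If the rescaled values are instead handed to a momentum-based optimizer as gradients, the optimizer accumulates them as they are, so a scale factor that changes between iterations also changes how strongly past steps weigh against the current one.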