Treating data/network weights differently

Hi, I ran into some issues implementing a simple idea. I want to determine which data samples in a batch, or which features in a feature space, are the most active during the optimization step. I am aware of other ways to reduce the impact of particular data samples on the optimization trajectory (such as gradient clipping), but I want more control over it.

For now, assume I am mostly interested in determining which neurons in LeNet (3 convolutional and 2 fully connected layers) receive the largest gradient updates, and in treating those neurons differently (either by dynamically selecting which parameters the optimizer updates or by using different optimizers). Similarly, I would like to treat different images within the batch differently, for instance by giving each image its own learning rate LR_Base * some_metric(image).
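A minimal sketch of measuring "which neurons get the largest gradient update": after a normal backward pass, take the L2 norm of each weight gradient over everything except dim 0, which gives one value per output channel (conv) or output unit (linear). With plain SGD the update of a neuron's weights is `lr * grad`, so this norm is proportional to the size of its update. The `make_lenet()` layout, the random stand-in batch and the choice to skip biases are all assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def make_lenet():
    # hypothetical LeNet variant for 28x28 MNIST: 3 conv layers + 2 fully connected layers
    return nn.Sequential(
        nn.Conv2d(1, 6, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(16, 120, 5), nn.ReLU(), nn.Flatten(),
        nn.Linear(120, 84), nn.ReLU(),
        nn.Linear(84, 10),
    )

def per_neuron_grad_norms(model):
    """Return {param_name: tensor[num_units]}: L2 norm of the weight gradient
    of each output channel (conv) or output unit (linear). Biases are skipped."""
    norms = {}
    for name, p in model.named_parameters():
        if p.grad is None or not name.endswith("weight"):
            continue
        # dim 0 indexes output channels / units; reduce over all other dims
        norms[name] = p.grad.flatten(start_dim=1).norm(dim=1)
    return norms

model = make_lenet()
X = torch.randn(32, 1, 28, 28)                 # stand-in for an MNIST batch
y = torch.randint(0, 10, (32,))
target = F.one_hot(y, num_classes=10).float()  # MSE needs float targets shaped like the output
F.mse_loss(model(X), target).backward()

norms = per_neuron_grad_norms(model)
for name, n in norms.items():
    top = n.topk(3).indices.tolist()           # the 3 most "active" units in this layer
    print(f"{name}: {n.numel()} units, top-3 = {top}")
```

The resulting per-unit norms can then feed whatever selection or scaling rule you choose.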

Finally, I want to do this efficiently: instead of looping over the samples as below, I would like to use vmap.

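```python
import torch

def compute_loss_for_single_instance(network, loss_function, image, label):
    # add a batch dimension of size 1 so the model sees a batch of one sample
    y_pred = network(image.unsqueeze(0))
    loss = loss_function(y_pred, label.unsqueeze(0))
    return loss

for (X, y) in iterate_dataset(dataset, batch_size):  # iterate_dataset: my own data loader helper
    # slow: one forward (and later one backward) pass per sample
    for image, label in zip(X, y):
        loss = compute_loss_for_single_instance(network, loss_fn, image, label)
        ...
```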
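If only the per-sample losses are needed, vmap is not required: built-in losses such as `nn.MSELoss` accept `reduction="none"`, so a single batched forward pass already yields one loss per image. It is the per-sample *gradients* that need extra machinery. A small sketch, assuming a hypothetical `make_lenet()` layout, random stand-in data and one-hot targets for MSE:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def make_lenet():
    # hypothetical LeNet variant for 28x28 MNIST: 3 conv layers + 2 fully connected layers
    return nn.Sequential(
        nn.Conv2d(1, 6, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(16, 120, 5), nn.ReLU(), nn.Flatten(),
        nn.Linear(120, 84), nn.ReLU(),
        nn.Linear(84, 10),
    )

network = make_lenet()
X = torch.randn(16, 1, 28, 28)                     # stand-in for an MNIST batch
y = torch.randint(0, 10, (16,))
target = F.one_hot(y, num_classes=10).float()      # MSE needs float targets shaped like the output

# reduction="none" keeps the elementwise squared errors: shape [16, 10]
elementwise = nn.MSELoss(reduction="none")(network(X), target)
per_sample_loss = elementwise.mean(dim=1)          # one loss per image, shape [16]

# reference: the per-sample loop gives the same values, one forward pass at a time
loop_loss = torch.stack(
    [nn.MSELoss()(network(X[i:i + 1]), target[i:i + 1]) for i in range(16)]
)
```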

For the network-parameter case, rather than creating several optimizers with grouped weights every time new data arrives, I would ideally have a single one, as in "Different learning rate for a specific layer", but without being restricted to a structured, layer-by-layer grouping.
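One way to keep a single optimizer without layer-wise parameter groups is to leave the optimizer alone and rescale the gradients elementwise just before `optimizer.step()`. For plain SGD (no momentum, no weight decay) this is exactly an elementwise learning rate `lr * scale`. For Adam, a fixed per-coordinate rescaling is (up to `eps`) cancelled by its normalization, so it is not equivalent there. The sketch below assumes the metric is a flat vector with one entry per scalar parameter, in `model.parameters()` order; `metric_to_scale` and the use of `|grad|` as the metric are hypothetical stand-ins. The only Python loop is over the ~10 parameter tensors, not over individual weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def make_lenet():
    # hypothetical LeNet variant for 28x28 MNIST: 3 conv layers + 2 fully connected layers
    return nn.Sequential(
        nn.Conv2d(1, 6, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(16, 120, 5), nn.ReLU(), nn.Flatten(),
        nn.Linear(120, 84), nn.ReLU(),
        nn.Linear(84, 10),
    )

def metric_to_scale(metric):
    # hypothetical smooth mapping: larger metric -> smaller step, values in (0, 1]
    return 1.0 / (1.0 + metric)

model = make_lenet()
params = list(model.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)       # one optimizer over all parameters

X = torch.randn(32, 1, 28, 28)                    # stand-in for an MNIST batch
target = F.one_hot(torch.randint(0, 10, (32,)), 10).float()

optimizer.zero_grad()
F.mse_loss(model(X), target).backward()

# stand-in metric: one value per scalar parameter (|grad|, flattened in parameters() order)
flat_grad = torch.cat([p.grad.flatten() for p in params])
scale = metric_to_scale(flat_grad.abs())

before = [p.detach().clone() for p in params]
with torch.no_grad():
    # split the flat vector into per-tensor chunks and rescale each gradient in place
    for p, s in zip(params, torch.split(scale, [p.numel() for p in params])):
        p.grad.mul_(s.view_as(p))
optimizer.step()                                  # per weight: p -= 0.1 * scale * grad
```

If the scales are already known before the backward pass, `p.register_hook(lambda g, s=s: g * s)` on each parameter applies the same rescaling automatically during `backward()`.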

So let’s go through an example. Assume I have a model (LeNet), a dataset (MNIST: img, label) and a chosen loss function (e.g. MSE). I initially set up the optimizer optim on all parameters. Now I want a metric that indicates when the parameters start to change significantly, for instance the gradient of the norm of the gradient.
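Reading "gradient of the norm of the gradient" as the derivative of ||g|| with respect to the parameters θ, where g is the loss gradient (one possible interpretation), it can be computed with a double backward pass: build the first gradient with `create_graph=True`, take its norm, and differentiate again. Mathematically the result is H g / ||g|| (H being the Hessian), i.e. a Hessian-vector product, costing roughly one extra backward pass. A batch-level sketch with a hypothetical `make_lenet()` and random stand-in data:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def make_lenet():
    # hypothetical LeNet variant for 28x28 MNIST: 3 conv layers + 2 fully connected layers
    return nn.Sequential(
        nn.Conv2d(1, 6, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(16, 120, 5), nn.ReLU(), nn.Flatten(),
        nn.Linear(120, 84), nn.ReLU(),
        nn.Linear(84, 10),
    )

model = make_lenet()
params = [p for p in model.parameters() if p.requires_grad]
X = torch.randn(32, 1, 28, 28)                     # stand-in for an MNIST batch
target = F.one_hot(torch.randint(0, 10, (32,)), 10).float()

loss = F.mse_loss(model(X), target)
# first derivative, keeping the graph so that it can be differentiated again
grads = torch.autograd.grad(loss, params, create_graph=True)
grad_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))

# second pass: d||g||/dθ = H g / ||g||, one tensor per parameter
grad_of_grad_norm = torch.autograd.grad(grad_norm, params)
```

If what you are after is instead how the gradient norm changes from one optimizer step to the next, recording `grad_norm.item()` each step and differencing consecutive values is enough and needs no second derivative.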

For the data samples, I came up with something like:

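```python
import torch

def compute_grad_norm(grads):
    # concatenate all (flattened) parameter gradients and take the global L2 norm
    grads = [param_grad.detach().flatten() for param_grad in grads if param_grad is not None]
    return torch.cat(grads).norm()

# per-sample losses in one vectorized forward pass; network and loss_fn are not mapped
vmap_loss = torch.vmap(compute_loss_for_single_instance, in_dims=(None, None, 0, 0))
params = list(network.parameters())

for (X, y) in iterate_dataset(train_dataset):
    losses = vmap_loss(network, loss_fn, X, y)
    # still one backward pass per sample: this is the slow part
    norm_gradients = torch.stack([
        compute_grad_norm(torch.autograd.grad(loss, params, retain_graph=True))
        for loss in losses
    ])
```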
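The list comprehension above still runs one backward pass per sample. `torch.func` (PyTorch 2.0 or later) can vectorize the backward pass too: `functional_call` runs the module with an explicit parameter dict, `grad` differentiates with respect to that dict, and `vmap` maps over the batch. Nesting `grad` once more gives a per-sample gradient of the gradient norm. A sketch assuming a hypothetical `make_lenet()`, MSE on one-hot targets and random stand-in data; `loss_single` and `grad_norm_single` are illustrative helper names:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

def make_lenet():
    # hypothetical LeNet variant for 28x28 MNIST: 3 conv layers + 2 fully connected layers
    return nn.Sequential(
        nn.Conv2d(1, 6, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(16, 120, 5), nn.ReLU(), nn.Flatten(),
        nn.Linear(120, 84), nn.ReLU(),
        nn.Linear(84, 10),
    )

network = make_lenet()
loss_fn = nn.MSELoss()
params = {k: v.detach() for k, v in network.named_parameters()}
buffers = {k: v.detach() for k, v in network.named_buffers()}

def loss_single(params, buffers, image, label):
    # one sample: add a batch dimension of size 1 and run the module with the given params
    y_pred = functional_call(network, (params, buffers), (image.unsqueeze(0),))
    return loss_fn(y_pred, label.unsqueeze(0))

def grad_norm_single(params, buffers, image, label):
    # global L2 norm of this sample's gradient w.r.t. all parameters
    g = grad(loss_single)(params, buffers, image, label)
    return torch.sqrt(sum((v ** 2).sum() for v in g.values()))

X = torch.randn(8, 1, 28, 28)                      # stand-in for an MNIST batch
target = F.one_hot(torch.randint(0, 10, (8,)), 10).float()

in_dims = (None, None, 0, 0)                       # share params/buffers, map over samples
per_sample_norms = vmap(grad_norm_single, in_dims=in_dims)(params, buffers, X, target)  # [8]
# per-sample gradient of the gradient norm: dict name -> tensor [8, *param.shape]
per_sample_grad_of_norm = vmap(grad(grad_norm_single), in_dims=in_dims)(params, buffers, X, target)
```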

(1) From here I could track this metric externally and watch how it evolves, but ideally I would like to modify the code above to measure the gradient of the norm of the gradient.
(2) How do I proceed from there? I don’t want to split the batch into two categories (e.g. by comparing to a threshold) and use two different optimizers; instead I would like a smooth way of “telling” the optimizer to apply a different learning rate to each sample.
(3) How do I carry this over to weight selection? Assuming I compute a metric vector with the same length as the total number of parameters in the network, how do I “tell” my optimizer to take a different gradient step for each parameter, ideally without looping over all parameters, which is quite costly?
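For (2): with plain SGD, a per-sample learning rate LR_Base * w_i is the same as weighting sample i's loss by w_i, because the batch gradient is just a weighted sum of per-sample gradients. With momentum or Adam it is only an approximation, since those optimizers transform the summed gradient. So one optimizer suffices: compute per-sample weights from a metric, detach them so they act as constants, and minimize the weighted loss. A sketch assuming SGD, a hypothetical `make_lenet()` and random stand-in data; `weights_from_metric` is an illustrative choice, and the per-sample loss is only a stand-in metric (per-sample gradient norms could be plugged in instead):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def make_lenet():
    # hypothetical LeNet variant for 28x28 MNIST: 3 conv layers + 2 fully connected layers
    return nn.Sequential(
        nn.Conv2d(1, 6, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(16, 120, 5), nn.ReLU(), nn.Flatten(),
        nn.Linear(120, 84), nn.ReLU(),
        nn.Linear(84, 10),
    )

def weights_from_metric(metric, temperature=1.0):
    # hypothetical smooth mapping: a larger metric gives a smaller weight;
    # weights are positive and average to 1, so LR_BASE keeps its usual meaning
    w = torch.softmax(-metric / temperature, dim=0)
    return w * metric.numel()

model = make_lenet()
LR_BASE = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=LR_BASE)   # a single optimizer

X = torch.randn(32, 1, 28, 28)                     # stand-in for an MNIST batch
target = F.one_hot(torch.randint(0, 10, (32,)), 10).float()

per_sample_loss = F.mse_loss(model(X), target, reduction="none").mean(dim=1)  # shape [32]
metric = per_sample_loss.detach()                  # stand-in metric, no gradient through it
w = weights_from_metric(metric)

optimizer.zero_grad()
# for plain SGD, sample i is effectively stepped with LR_BASE * w[i]
(w * per_sample_loss).mean().backward()
optimizer.step()
```

Changing `temperature` makes the weighting smoother or sharper without ever introducing a hard threshold.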

Many thanks for your help!