When doing SGD, we can apply the update rule to each parameter tensor separately in a for loop. So for every layer in a feedforward neural network, we update its weights and biases.

However, I want to implement a different version of SGD where only the k parameters with the largest gradient entries (in absolute value) are actually updated. The crucial point: these should be the k largest gradients over all model parameters. If k=1, this means we only update the single parameter in the whole network with the largest gradient entry, not the largest entry in each layer.

How would I find the maximal elements of the entire gradient vector efficiently?
Thanks a lot!
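One way to pick the global top-k without building a full mask is to flatten all |grad| values once, run torch.topk, and then map each flat index back to the parameter it came from using cumulative sizes and torch.searchsorted. This is a minimal sketch, not an established PyTorch API: global_topk_locations and topk_sgd_step are hypothetical helper names, and it assumes plain SGD (no momentum, no weight decay) and contiguous parameter tensors.

```python
import torch


def global_topk_locations(params, k):
    """Hypothetical helper: (owner, local) indices of the k largest |grad|
    entries across all tensors in `params`."""
    sizes = torch.tensor([p.numel() for p in params])
    # Flatten every |grad| into one vector, as with the torch.cat approach
    flat = torch.cat([p.grad.detach().abs().reshape(-1) for p in params])
    top_idx = torch.topk(flat, k).indices

    sizes = sizes.to(flat.device)
    # offsets[i] is the flat index one past the end of parameter i
    offsets = torch.cumsum(sizes, dim=0)
    # Which parameter each selected flat index falls into
    owner = torch.searchsorted(offsets, top_idx, right=True)
    # Position inside that parameter's flattened tensor
    starts = offsets - sizes
    local = top_idx - starts[owner]
    return owner, local


@torch.no_grad()
def topk_sgd_step(params, k, lr):
    """Hypothetical helper: plain SGD step on the k globally largest entries only."""
    owner, local = global_topk_locations(params, k)
    for o, l in zip(owner.tolist(), local.tolist()):
        p = params[o]
        # view(-1) shares storage with p, so this writes into the parameter
        p.view(-1)[l] -= lr * p.grad.reshape(-1)[l]
```

For small k this touches only k entries in the update, but the torch.cat and torch.topk over the full gradient are still needed to find them.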

Thanks for the fast reply. That's what I am doing now, but I thought there might be a more efficient way. I tried to use the torch.nn.utils.parameters_to_vector function, but it only works with Parameters. Given my solution below, can this be done more efficiently? My idea was to compute the mask once inside a generator, and then in every per-layer step of SGD yield the relevant slice of the mask to multiply the gradient with. I'm not sure this is a good approach; it is quite slow compared to plain SGD. I call the function with the following list: param_list = [p for group in self.param_groups for p in group['params'] if p.grad is not None].

```python
import torch

def get_mask_generator(self, param_list):
    """Generator yielding a top-k mask for each parameter in param_list."""
    # Flatten |grad| of all parameters into a single vector
    grad_vector = torch.cat([torch.abs(p.grad).view(-1) for p in param_list])
    grad_vector_shape = grad_vector.shape
    device = grad_vector.device
    # Indices of the self.Q largest entries over the whole model
    top_indices = torch.topk(grad_vector, k=self.Q).indices
    del grad_vector
    mask_vector = torch.zeros(grad_vector_shape, device=device)
    mask_vector[top_indices] = 1
    # Define the generator (note: the code above is run only once)
    for p in param_list:
        num_el = p.numel()
        # Take this parameter's slice of the mask and reshape it to p's shape
        partial_mask = mask_vector[:num_el]
        mask_vector = mask_vector[num_el:]
        yield partial_mask.view(p.shape)
```

Well, parameters_to_vector does the same thing as your cat operation, so it will be just as efficient.
I don't think you can do much better, really: manual bookkeeping of the values will most likely end up being more expensive than these few large ops.
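A small check of the claim above, assuming a reasonably recent PyTorch version: parameters_to_vector flattens each tensor and concatenates the results, so it also accepts plain gradient tensors (not only Parameters) and produces the same vector as the torch.cat version. The tiny Linear model is only there to produce some gradients.

```python
import torch
from torch.nn.utils import parameters_to_vector

# A tiny model, used only to produce some gradients
model = torch.nn.Linear(4, 3)
loss = model(torch.randn(8, 4)).pow(2).sum()
loss.backward()

grads = [p.grad for p in model.parameters()]
# parameters_to_vector: flatten each tensor, then concatenate
vec_a = parameters_to_vector(grads)
# The manual equivalent from the snippet above
vec_b = torch.cat([g.reshape(-1) for g in grads])
```

Both produce one vector of 4 * 3 + 3 = 15 entries, built by the same underlying concatenation.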

I'm not sure why you need this to be a generator, though. Can't you directly update all the p.grad in place from this function?
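A minimal sketch of that suggestion: build the global top-k mask once, split it per parameter with torch.split, and multiply each p.grad in place, so a normal optimizer step can follow. sparsify_grads_topk_ is a hypothetical name. This assumes plain SGD with no momentum and no weight decay: with either of those, parameters whose gradient was zeroed can still move.

```python
import torch


@torch.no_grad()
def sparsify_grads_topk_(params, k):
    """Hypothetical helper: zero every gradient entry except the k largest
    |grad| values over all params, modifying p.grad in place."""
    params = [p for p in params if p.grad is not None]
    flat = torch.cat([p.grad.abs().reshape(-1) for p in params])
    mask = torch.zeros_like(flat)
    mask[torch.topk(flat, k).indices] = 1
    # torch.split returns one contiguous chunk of the mask per parameter
    chunks = torch.split(mask, [p.numel() for p in params])
    for p, m in zip(params, chunks):
        p.grad.mul_(m.view_as(p.grad))
```

Usage: call it between loss.backward() and optimizer.step(), e.g. sparsify_grads_topk_(model.parameters(), k) followed by torch.optim.SGD(model.parameters(), lr=0.1).step(). Only the k selected entries then receive a nonzero update.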