Here are two approaches for your use case.
Hard-threshold capping:
import torch
loss = torch.rand(6, requires_grad=True) # let's say we have a per-element loss tensor
torch.clip(loss, max=0.5).sum().backward() # cap each loss at 0.5
# loss.grad contains 0 wherever the loss value exceeds 0.5
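To make the zero-gradient behavior concrete, here is a small check on a fixed tensor (the values are illustrative, not from the original):

```python
import torch

# Fixed losses so the gradient pattern is predictable
loss = torch.tensor([0.2, 0.8, 0.4, 0.9], requires_grad=True)
torch.clip(loss, max=0.5).sum().backward()
print(loss.grad)  # entries above the 0.5 cap receive zero gradient
```

Entries below the cap get gradient 1 from the sum; entries above it are flattened by the clip, so their gradient is 0.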
Ignoring top-k losses:
import torch
loss = torch.rand(6, requires_grad=True)
k = 3
sorted_loss, _ = torch.sort(loss) # ascending; sort is differentiable
sorted_loss[:-k].sum().backward() # drop the k largest losses
# loss.grad contains 0 at the positions of the top-k losses
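An equivalent sketch uses torch.topk with largest=False to keep the n - k smallest losses directly, avoiding a full sort (the fixed tensor here is illustrative):

```python
import torch

loss = torch.tensor([0.2, 0.8, 0.4, 0.9, 0.1, 0.7], requires_grad=True)
k = 3
# Select the (n - k) smallest losses; gradients route back to their original positions
kept, _ = torch.topk(loss, loss.numel() - k, largest=False)
kept.sum().backward()
print(loss.grad)  # 0 at the positions of the k largest losses
```

Both versions produce the same gradient mask; topk is typically cheaper than a full sort when k is small relative to n.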