Loss function not backpropagating

I wrote the loss function below to try to penalise the largest input gradients (the gradients of the loss with respect to the input images). Everything runs fine at first, but the extra loss term doesn't seem to achieve anything. So I check whether images.grad changes after I call grad_loss.backward(), and it doesn't change at all.
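
Concretely, this is roughly the check I run, as a simplified sketch (images, grad_loss_func and the two backward() calls are the ones from the training loop below):

grad_before = images.grad.detach().clone()  # input gradient from the first backward()
grad_loss = grad_loss_func(images.grad)
grad_loss.backward()
# I expected the second backward() to accumulate into images.grad,
# but this prints 0.0 on every batch
print((images.grad - grad_before).abs().max().item())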

Loss function:

import torch


class LInfLoss:
    def __init__(self, weight=1, tolerance=0.001):
        self.weight = torch.tensor(weight).float()
        self.weight.requires_grad = True
        self.weight = self.weight.to('cuda:0')
        self.tolerance = tolerance

    def __call__(self, gradient: torch.Tensor):
        # penalise every entry that is within `tolerance` of the largest gradient
        threshold = gradient.max() * (1 - self.tolerance)
        grads_over_threshold = gradient.ge(threshold).float() * self.weight
        return grads_over_threshold.sum()
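
Called on its own it returns the kind of value I expect, e.g. this toy example with made-up shapes (it needs a CUDA device because of the hard-coded 'cuda:0'):

grad_loss_func = LInfLoss(weight=1, tolerance=0.001)
fake_grad = torch.rand(8, 3, 32, 32, device='cuda:0')  # stand-in for images.grad
# scalar tensor: (number of entries within 0.1% of the max) * weight
print(grad_loss_func(fake_grad))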

Training loop:

    for (images, one_hot_labels) in tqdm(batched_train_data):
        # I grab the batch size here because the last batch may be smaller
        batch_size = images.shape[0]
        images.requires_grad = True
        
        optimizer.zero_grad()
        # as images is not a parameter, optimizer.zero_grad() won't reset its gradient
        if images.grad is not None:
            images.grad = None

        probabilities = model.forward(images)

        # I call .backward() twice rather than using torch.autograd.grad because
        # I want the gradients to accumulate (the alternative is sketched after the loop)
        loss = loss_func(probabilities, one_hot_labels)
        loss.backward(create_graph=True)
        grad_loss = grad_loss_func(images.grad)
        grad_loss.backward()

        optimizer.step()

        labels = one_hot_labels.detach().argmax(dim=1)
        predictions = probabilities.detach().argmax(dim=1)
        num_correct = int(predictions.eq(labels).sum())

        train_data_length += batch_size
        train_correct += num_correct
        train_loss += float(loss.detach()) * batch_size

        # To stop memory leaks
        images.grad = None
        del probabilities
        del loss
        del grad_loss
        del labels
        del predictions
        del num_correct
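
For reference, this is roughly the torch.autograd.grad variant I decided against (just a sketch, not what I'm running), where the input gradient is returned directly instead of being read back out of images.grad:

# hypothetical alternative to the double backward() above
input_grads, = torch.autograd.grad(loss, images, create_graph=True)
grad_loss = grad_loss_func(input_grads)
grad_loss.backward()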

EDIT:
Just to clarify, grad_loss_func is an instance of the LInfLoss class above.