I wrote the following loss function to try to penalise the largest gradients with respect to the input. Everything runs fine at first, but I noticed that the extra loss term wasn't achieving anything. So I checked whether the input's gradient changes after I call grad_loss.backward(), and it doesn't change at all.
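Roughly, this is how I checked it (a minimal sketch using the names from the training loop below; grad_before is just a name made up for this snippet):

    grad_before = images.grad.clone()       # snapshot the input gradient
    grad_loss.backward()
    torch.equal(grad_before, images.grad)   # returns True, i.e. nothing changed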
Loss function:
import torch

class LInfLoss:
    def __init__(self, weight=1, tolerance=0.001):
        self.weight = torch.tensor(weight).float()
        self.weight.requires_grad = True
        self.weight = self.weight.to('cuda:0')
        self.tolerance = tolerance

    def __call__(self, gradient: torch.Tensor):
        # treat every gradient within `tolerance` of the maximum as one of the highest
        threshold = gradient.max() * (1 - self.tolerance)
        grads_over_threshold = gradient.ge(threshold).float() * self.weight
        return grads_over_threshold.sum()
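For example, on a dummy gradient tensor it behaves like this (a minimal sketch with made-up values; it needs a CUDA device because the class hard-codes cuda:0):

    loss_fn = LInfLoss(weight=1, tolerance=0.001)
    dummy_grad = torch.tensor([0.1, 0.5, 1.0, 0.9995], device='cuda:0')
    # threshold = 1.0 * (1 - 0.001) = 0.999, so 1.0 and 0.9995 both count
    loss_fn(dummy_grad)  # evaluates to 2.0 (two entries times weight 1)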
Training loop:
for (images, one_hot_labels) in tqdm(batched_train_data):
    # I collect the batch size here because the last batch may be smaller
    batch_size = images.shape[0]
    images.requires_grad = True
    optimizer.zero_grad()
    # as images is not a parameter, optimizer.zero_grad() won't reset its gradient
    if images.grad is not None:
        images.grad = None
    probabilities = model.forward(images)
    # I want to use .backward() twice rather than autograd because I want to
    # accumulate the gradients (see the sketch after this loop)
    loss = loss_func(probabilities, one_hot_labels)
    loss.backward(create_graph=True)
    grad_loss = grad_loss_func(images.grad)
    grad_loss.backward()
    optimizer.step()
    labels = one_hot_labels.detach().argmax(dim=1)
    predictions = probabilities.detach().argmax(dim=1)
    num_correct = int(predictions.eq(labels).sum())
    train_data_length += batch_size
    train_correct += num_correct
    train_loss += float(loss.detach()) * batch_size
    # to stop memory leaks
    images.grad = None
    del probabilities
    del loss
    del grad_loss
    del labels
    del predictions
    del num_correct
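For reference, the autograd alternative I decided against would look roughly like this (a sketch, not my actual code; torch.autograd.grad returns the input gradient directly instead of accumulating it into images.grad):

    # would replace loss.backward(create_graph=True) ... optimizer.step() above
    input_grad, = torch.autograd.grad(loss, images, create_graph=True)
    grad_loss = grad_loss_func(input_grad)
    (loss + grad_loss).backward()  # one backward pass through both loss terms
    optimizer.step()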
EDIT:
just to clarify, grad_loss_func is an instance of LInfLoss, i.e. grad_loss_func = LInfLoss()