Clip/limit the loss for outlier samples in a batch

Hello!

I’m working on a project where we are building a simple image classifier that we intend to deploy on a video stream. Since the video frames look quite different from standard datasets, we mostly train on our own data, and this is where the problem starts.

The labels for our datasets are generated via an automatic process with some added manual oversight. The resulting dataset contains a small amount of dirty data (unclear images where only part of the class object, or none of it, is in frame). These images are hard even for humans to classify and might not be suitable to train on. Unfortunately, there is no easy way to remove them from the dataset, so I’m looking for an alternative way to handle them.

Since this dirty data is relatively rare (say 5%), I was hoping to reduce its impact on training by identifying these samples at runtime and limiting the loss they produce. I’ve come up with two solutions.

Solution 1
The assumption is that the dirty data would show up as samples with high loss, and this code would limit their loss.

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss(reduction="none")  # keep per-sample losses
loss = loss_fn(outputs, labels)

loss_threshold = loss.median() * 5  # Hyperparameter
loss_scales = torch.ones_like(loss)
for loss_idx, loss_val in enumerate(loss):
    if loss_val > loss_threshold:
        # Detach so the scale acts as a constant. Without .detach(),
        # loss * (loss_threshold / loss) collapses to loss_threshold and
        # the sample's own gradient cancels out.
        loss_scales[loss_idx] = (loss_threshold / loss_val).detach()

loss = loss * loss_scales  # Ensures no sample-loss is greater than loss_threshold
loss = loss.mean()
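As an aside, the same scaling can be written without the Python loop; a minimal sketch using torch.where (assuming the same loss and loss_threshold tensors as above):

loss_scales = torch.where(
    loss > loss_threshold,
    (loss_threshold / loss).detach(),  # constant scale factors for outliers
    torch.ones_like(loss),             # leave inliers untouched
)
loss = (loss * loss_scales).mean()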

Solution 2
Same idea as solution 1, but using PyTorch’s clamp function.

loss_fn = nn.CrossEntropyLoss(reduction="none")
loss = loss_fn(outputs, labels)

loss_threshold = loss.median() * 5  # Hyperparameter
# Note: .item() detaches the threshold, and clamp has zero derivative
# above max, so clamped samples contribute no gradient at all.
loss = loss.clamp(max=loss_threshold.item())
loss = loss.mean()

It is not obvious to me whether these solutions are good ideas to try. For example, I don’t quite understand what happens to the gradients when you limit/clamp the loss as in these solutions (some info here: torch.clamp kills gradients at the border · Issue #7002 · pytorch/pytorch · GitHub).
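Here is a tiny standalone check (my own sketch, with made-up numbers) of what clamping does to the gradients; the third sample sits above the max and ends up with a zero gradient:

import torch

loss = torch.tensor([0.5, 1.0, 10.0], requires_grad=True)
loss.clamp(max=2.0).mean().backward()
print(loss.grad)  # tensor([0.3333, 0.3333, 0.0000]) -- clamped sample gets no gradient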

Based on a few experiments, these solutions perform worse than the baseline. One thing that actually works is limiting the gradients via the PyTorch function torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1).
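For reference, this is where the clipping call sits in a training step (sketch; model, optimizer, loss_fn, and train_loader are whatever you already have):

for images, labels in train_loader:
    optimizer.zero_grad()
    loss = loss_fn(model(images), labels).mean()
    loss.backward()
    # Rescale the overall gradient norm before the optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
    optimizer.step()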

Any suggestions to solve my problem would be appreciated :sun_with_face:

Edit: Found some other threads that talk about this. No conclusive answer yet.

Hi,

One possible way to deal with this might be to start with a lower learning rate. This will prevent any outlier from having a profound effect on the model weights.

I also suggest a learning rate scheduler like cosine annealing: once the model starts to converge, you don’t want an outlier to make changes that slow convergence down. A scheduler gradually lowers the learning rate as epochs progress, which could help.
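For example, a minimal sketch with PyTorch’s built-in scheduler (assuming you already have an optimizer and a num_epochs variable):

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    train_one_epoch()  # placeholder for your existing training loop
    scheduler.step()   # decay the learning rate once per epoch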

Sarthak Jain

Hi @SarthakJain

Thanks for your input! I’m already doing linear learning rate warmup and cosine annealing to decrease my learning rate.

What I’m worried about is that my 5% of dirty data produces much more than 5% of the loss, and hence affects the model’s loss landscape/trajectory more than I’d like (even if the learning rate is low).
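To make that concrete, here’s a quick diagnostic I can run on the per-sample losses (sketch, reusing the loss tensor from the reduction="none" snippets above):

with torch.no_grad():
    k = max(1, int(0.05 * loss.numel()))  # highest-loss 5% of the batch
    top_losses, _ = loss.topk(k)
    print(top_losses.sum() / loss.sum())  # fraction of the batch loss they produce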

Bump for the Americans who might have missed this over July 4th