I’m working on a project where we are building a simple image classifier that we intend to deploy on a video stream. Since the images from the video look quite unusual, we mostly rely on our own datasets, and this is where the problem starts.
The labels for our datasets are generated via an automatic process with some added manual oversight. The resulting dataset contains a small amount of dirty data (unclear images where only part of the class object, or nothing at all, is in frame). These images are hard even for humans to classify and might not be suitable to train on. Unfortunately, there is no easy way to remove them from the dataset, so I’m looking for an alternative way to handle them.
Since this dirty data is relatively rare (say 5%), I was hoping to reduce its impact on training by identifying these samples at runtime and limiting the loss they produce. I’ve come up with two solutions.
Solution 1: The assumption is that the dirty data would show up as samples with high loss, and this code would limit their loss.
```python
loss_fn = nn.CrossEntropyLoss(reduction="none")
loss = loss_fn(outputs, labels)
loss_threshold = loss.median() * 5  # Hyperparameter

loss_scales = torch.ones_like(loss)
for loss_idx, loss_val in enumerate(loss):
    if loss_val > loss_threshold:
        loss_scales[loss_idx] = loss_threshold / loss[loss_idx]

loss = loss * loss_scales  # Ensures no sample-loss is greater than loss_threshold
loss = loss.mean()
```
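As a side note, the per-sample loop above can also be written without Python-level iteration using `torch.where`; a minimal vectorized sketch (the `outputs`/`labels` tensors here are dummy data for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
outputs = torch.randn(8, 3, requires_grad=True)  # dummy logits for illustration
labels = torch.randint(0, 3, (8,))

loss_fn = nn.CrossEntropyLoss(reduction="none")
per_sample = loss_fn(outputs, labels)
loss_threshold = per_sample.median() * 5  # Hyperparameter

# torch.where selects threshold/loss for the high-loss samples, 1.0 elsewhere
loss_scales = torch.where(per_sample > loss_threshold,
                          loss_threshold / per_sample,
                          torch.ones_like(per_sample))
scaled = per_sample * loss_scales  # no sample-loss exceeds loss_threshold
loss = scaled.mean()
```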
Solution 2: Same idea as solution 1, but using PyTorch’s clamp function.
```python
loss_fn = nn.CrossEntropyLoss(reduction="none")
loss = loss_fn(outputs, labels)
loss_threshold = loss.median() * 5  # Hyperparameter
loss = loss.clamp(max=loss_threshold.item())
loss = loss.mean()
```
It is not obvious to me whether these solutions are good ideas to try. For example, I don’t quite understand what happens to the gradients when you limit/clamp the loss as in these solutions (some info here: torch.clamp kills gradients at the border · Issue #7002 · pytorch/pytorch · GitHub).
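To make the gradient behaviour concrete, here is a minimal sketch of what clamp does in the backward pass: elements above the threshold contribute zero gradient. (Note also that `loss * (loss_threshold / loss)` is numerically just `loss_threshold`, so solution 1 effectively zeroes those samples’ gradients as well, apart from the indirect dependence through the batch median.)

```python
import torch

x = torch.tensor([1.0, 2.0, 10.0], requires_grad=True)
x.clamp(max=3.0).sum().backward()
print(x.grad)  # the element above the threshold receives zero gradient
```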
Based on a few experiments, these solutions perform worse than the baseline. One thing that does actually work is limiting the gradients via the corresponding PyTorch function.
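Concretely, gradient limiting can be sketched like this (assuming the function in question is `torch.nn.utils.clip_grad_norm_`; the model and data below are dummies for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4)  # dummy inputs
y = torch.randint(0, 2, (8,))  # dummy labels

opt.zero_grad()
loss = nn.CrossEntropyLoss()(model(x), y)
loss.backward()
# Rescale all gradients in place so their total norm is at most 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step()
```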
Any suggestions for solving this problem would be appreciated.
Edit: Found some other threads that discuss this. No conclusive answer yet.