I have built a text classifier with 9 different labels for each data sample, one of which is the “none” label.
The output of my network is a softmax layer that gives probabilities for the 9 labels. When calculating metrics, I threshold this tensor (> 0.5 => 1.0).
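A small sketch of that thresholding step. The first row is the confused example from this post and the second row is a made-up confident output. Because each softmax row sums to 1, at most one entry can be strictly greater than 0.5, so thresholding turns every row into either a one-hot vector or all zeros. The all-zero rows are the “confused” ones.

```python
import torch

# Two example softmax outputs over the 9 labels (index 8 = "none").
probs = torch.tensor([
    [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.1, 0.1],          # "confused" row
    [0.02, 0.8, 0.02, 0.02, 0.02, 0.02, 0.05, 0.02, 0.03],  # confident row (made up)
])

threshold = 0.5
preds = (probs > threshold).float()  # the (> 0.5 => 1.0) thresholding step

# Each row sums to 1, so at most one entry can exceed 0.5:
# every thresholded row is either one-hot or all zeros.
no_label = preds.sum(dim=1) == 0
print(preds)
print(no_label)
```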
Now, especially early in training, many samples have no label whose probability exceeds the threshold. So I “clamp” these “confused” vectors: the 8 actual labels get a very small probability and most of the probability mass is assigned to the “none” label.
For example, a confused tensor could be
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.1, 0.1]
and I turn it into
[1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 0.992] (this is what I’m calling clamping).
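A quick check of the arithmetic behind the clamped vector: with 8 real labels at 1e-3 each, the “none” label gets 1 - 8 * 0.001 = 0.992. The vector therefore still sums to 1, and after the > 0.5 threshold it becomes one-hot on “none” (index 8, assuming “none” is the last label as in the example).

```python
import torch

margin, n_classes = 1e-3, 8             # values from the post
none_prob = 1.0 - margin * n_classes     # 1 - 8 * 0.001 = 0.992
clamped = torch.cat((torch.full((n_classes,), margin), torch.tensor([none_prob])))

# Still a valid probability distribution (up to float32 rounding) ...
print(clamped.sum())
# ... and thresholding at 0.5 leaves only the "none" label (index 8) set.
print((clamped > 0.5).float())
```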
After some epochs, the share of clamped samples drops to around 0.001 of all data points.
Anyhow, my implementation of this “manual clamping layer” is rather naive, and I would love it if someone could suggest a faster implementation!
```python
import torch


def clamp_layer(self, log_probs, probs, threshold=0.5):
    clamp_counter = 0
    modified_log_probs = torch.zeros_like(log_probs).float()
    modified_probs = torch.zeros_like(probs).float()

    ## 99.2 % prob for none relation
    margin, n_classes = 1e-3, 8
    non_vector = torch.cat((torch.full((n_classes,), margin),
                            torch.tensor([1.0 - margin * n_classes])))
    non_log_vector = torch.log(non_vector)

    batch_size = log_probs.size(0)
    for batch_idx in range(batch_size):
        # "Confused" sample: no label reaches the threshold
        no_relation_above_threshold = torch.ge(probs[batch_idx], threshold).sum() == 0
        if no_relation_above_threshold:
            # Replace the row with the fixed "mostly none" distribution
            modified_probs[batch_idx] = non_vector
            modified_log_probs[batch_idx] = non_log_vector
            clamp_counter += 1
        else:
            # Keep the network's output unchanged
            modified_probs[batch_idx] = probs[batch_idx]
            modified_log_probs[batch_idx] = log_probs[batch_idx]

    modified_log_probs.requires_grad_(True)
    modified_probs.requires_grad_(True)
    return modified_log_probs, modified_probs, clamp_counter
```
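A possible loop-free version, offered as a sketch rather than a drop-in replacement: it builds the mask of confused rows for the whole batch at once and uses `torch.where` to swap in the fixed distribution. `clamp_layer_vectorized` and its `margin` parameter are hypothetical names; it assumes inputs of shape `(batch, 9)` with “none” as the last label, as in the code above.

```python
import torch


def clamp_layer_vectorized(log_probs, probs, threshold=0.5, margin=1e-3):
    """Loop-free sketch of clamp_layer (hypothetical name).

    Assumes probs/log_probs have shape (batch, n_labels) and that the last
    label is "none". Rows with no probability >= threshold are replaced by a
    fixed distribution (`margin` on each real label, the rest on "none");
    all other rows pass through unchanged.
    """
    n_labels = probs.size(1)  # 9 in the post: 8 real labels + "none"
    n_real = n_labels - 1

    # Replacement distribution, created on the inputs' device and dtype.
    clamp_probs = torch.full((n_labels,), margin, dtype=probs.dtype, device=probs.device)
    clamp_probs[-1] = 1.0 - margin * n_real
    clamp_log_probs = clamp_probs.log()

    # (batch, 1) mask: True for rows where no label reaches the threshold.
    confused = ~(probs >= threshold).any(dim=1, keepdim=True)

    # torch.where broadcasts the (n_labels,) replacement over the masked rows;
    # gradients only reach probs/log_probs in the rows that are kept.
    modified_probs = torch.where(confused, clamp_probs, probs)
    modified_log_probs = torch.where(confused, clamp_log_probs, log_probs)
    clamp_counter = int(confused.sum().item())
    return modified_log_probs, modified_probs, clamp_counter


# Usage: the first row is uniform ("confused"), the second is confident.
logits = torch.tensor([[0.0] * 9, [0.0, 3.0] + [0.0] * 7], requires_grad=True)
log_probs = torch.log_softmax(logits, dim=1)
probs = log_probs.exp()
new_log_probs, new_probs, n_clamped = clamp_layer_vectorized(log_probs, probs)
print(n_clamped)  # 1
```

Since `torch.where` is tracked by autograd, the outputs require grad whenever `probs`/`log_probs` do, so the explicit `requires_grad_(True)` calls should not be needed. The clamped rows act as constants, just as they do in the loop version.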
My second question: how does one handle softmax outputs that are all below a certain threshold? Is this the way to go, are there papers or theories that discuss this issue, or am I doing something completely sketchy and this should not happen?
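One option, sketched here as an assumption rather than an established recipe: because a softmax output is a single distribution over mutually exclusive labels, predictions can be taken with argmax, which always yields exactly one label. If low-confidence rows should still count as “none”, that fallback can be applied only when computing metrics, leaving the probabilities that feed the loss untouched. `predict_labels` and `none_idx` are hypothetical names, and “none” is assumed to be index 8.

```python
import torch


def predict_labels(probs, threshold=0.5, none_idx=8):
    """Metric-time prediction sketch: argmax with a "none" fallback.

    Assumes probs has shape (batch, 9) with "none" at index `none_idx`.
    Rows whose top probability does not exceed `threshold` are predicted
    as "none"; the probabilities themselves are not modified.
    """
    top_prob, top_idx = probs.max(dim=1)
    fallback = torch.full_like(top_idx, none_idx)
    return torch.where(top_prob > threshold, top_idx, fallback)


probs = torch.tensor([
    [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.1, 0.1],
    [0.02, 0.8, 0.02, 0.02, 0.02, 0.02, 0.05, 0.02, 0.03],
])
print(probs.argmax(dim=1))    # plain argmax: exactly one label per row
print(predict_labels(probs))  # argmax, with "none" for confused rows
```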