So, I have been trying to implement the distilled model concept from Hinton et al.'s "Dark Knowledge" paper.
Accordingly, I trained a cumbersome model, and based on the results from the cumbersome model, I have to train the smaller model to fit the data.
Now, the output from the big cumbersome model has shape (batch_size, outputs), which is the same as the output shape of the small model.
Now, the paper says to take a cross-entropy loss against the outputs from the big model, which act as soft targets for the small model. However, neither the shape nor the data type (FloatTensor) of those outputs is accepted as a target by the cross-entropy loss in the PyTorch library. I think there's an API available for this in TensorFlow.
I have tried a KL divergence loss between the outputs of the big model and the small one, plus CrossEntropyLoss between the outputs of the small model and the actual targets (hard labels), but it generalized very poorly; it seems to overfit badly.
How can I implement a custom loss function between the outputs from the big model (soft targets) and the outputs from the small model, which are the same size?
Yes, pytorch’s CrossEntropyLoss is a special case of cross-entropy
that requires integer categorical labels (“hard targets”) for its targets.
(It also takes logits, rather than probabilities, for its predictions.)
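To make the "hard targets" case concrete, here is a minimal sketch. The batch size, class count, and values are made up for illustration. It shows the built-in loss taking logits plus integer class indices, and the same number computed by hand.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Made-up batch: 4 samples, 3 classes (illustrative values only).
logits = torch.randn(4, 3)                  # raw, unnormalized scores (float)
hard_targets = torch.tensor([0, 2, 1, 2])   # integer class indices (long)

# Built-in cross-entropy: logits as predictions, integer labels as targets.
builtin_loss = F.cross_entropy(logits, hard_targets)

# The same quantity by hand: mean negative log-probability of the true class.
log_probs = F.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(4), hard_targets].mean()
```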
It does sound like you want a general cross-entropy loss that takes
probabilities (“soft targets”) for its targets. This general version is not
built into pytorch.
But you can implement the general version using pytorch tensor
operations. See this earlier thread:
Note that the softXEnt() implemented in this post also takes logits
for its predictions. If your use case requires you to pass in probabilities
for your predictions (less numerically stable), you will have to modify softXEnt() accordingly.
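As a rough sketch of what such a loss can look like (not necessarily identical to the softXEnt() in the linked thread), the two functions below take probabilities for the targets. The names `soft_cross_entropy` and `soft_cross_entropy_from_probs`, and the `eps` clamp value, are my own choices for illustration. The first takes logits for the predictions; the second takes probabilities and is the less numerically stable variant.

```python
import torch
import torch.nn.functional as F


def soft_cross_entropy(logits, soft_targets):
    """Cross-entropy with probability targets and logit predictions.

    Both arguments have shape (batch_size, num_classes). The result is
    averaged over the batch.
    """
    log_probs = F.log_softmax(logits, dim=1)
    return -(soft_targets * log_probs).sum(dim=1).mean()


def soft_cross_entropy_from_probs(probs, soft_targets, eps=1e-12):
    """Same loss, but with predictions already given as probabilities.

    The clamp avoids log(0); eps is an illustrative assumption.
    """
    log_probs = torch.log(probs.clamp(min=eps))
    return -(soft_targets * log_probs).sum(dim=1).mean()
```

With one-hot targets, the logit version reduces to the usual hard-label cross-entropy, which is a handy sanity check.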
Yes, I think this might solve my problem.
I have to pass the probabilities generated by the big model, after applying softmax to its outputs, to the loss function as targets, where the targets and the small model's outputs are the same size.
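Putting that together, one training step might look like the sketch below. The `teacher` and `student` modules are hypothetical stand-ins (single linear layers) for the big and small models, and the input sizes are made up. The big model's outputs go through softmax to become the soft targets, while the small model's raw logits go into the loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the big (cumbersome) and small models.
teacher = nn.Linear(10, 3)
student = nn.Linear(10, 3)
x = torch.randn(8, 10)  # made-up batch: 8 samples, 10 features


def soft_cross_entropy(logits, soft_targets):
    # Probability targets, logit predictions, averaged over the batch.
    log_probs = F.log_softmax(logits, dim=1)
    return -(soft_targets * log_probs).sum(dim=1).mean()


# Soft targets: softmax over the big model's outputs, with no gradient
# flowing back into the big model.
with torch.no_grad():
    soft_targets = F.softmax(teacher(x), dim=1)

# The small model's raw outputs (logits), same shape as the soft targets.
student_logits = student(x)

loss = soft_cross_entropy(student_logits, soft_targets)
loss.backward()  # gradients reach only the small model's parameters
```

This loss differs from the KL divergence between the two distributions only by the entropy of the soft targets, which does not depend on the small model, so both give the same gradients for the small model.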