Online Knowledge Distillation

Dear All,

I hope you are all doing well and staying safe during the current COVID-19 pandemic. I am very new to computer vision and am currently working on a knowledge distillation problem. I am implementing knowledge distillation using mutual learning, in which four CNN models (M1, M2, M3, M4) are trained simultaneously. All four models run their forward passes and compute their losses. These losses are summed, and the total is backpropagated to update the weights of all four models. I want to create a new CNN model, M5, that takes as input the feature maps of M2 and M3 taken just before their FC layers. M5 consists of one convolutional layer, a BatchNorm layer, a fully connected layer, and an output layer.
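The "sum the losses and backpropagate once" setup described above can be sketched as follows. This is a minimal illustration, not the poster's code: the helper `mutual_learning_loss` is hypothetical, the per-model loss (cross-entropy to the labels plus the average KL divergence to each peer's detached prediction) is an assumption based on the usual deep-mutual-learning formulation, and tiny `nn.Linear` layers stand in for the four ResNet-34 models.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def mutual_learning_loss(logits_list, targets):
    """Hypothetical helper: sum over models of CE(labels) + mean KL(peer || model)."""
    total = logits_list[0].new_zeros(())
    n = len(logits_list)
    for i, logits in enumerate(logits_list):
        loss_i = F.cross_entropy(logits, targets)
        log_p_i = F.log_softmax(logits, dim=1)
        for j, peer in enumerate(logits_list):
            if j == i:
                continue
            # Peers are treated as fixed targets for model i, hence detach()
            p_j = F.softmax(peer.detach(), dim=1)
            loss_i = loss_i + F.kl_div(log_p_i, p_j, reduction="batchmean") / (n - 1)
        total = total + loss_i
    return total

torch.manual_seed(0)
models = [nn.Linear(8, 10) for _ in range(4)]  # stand-ins for M1..M4
params = [p for m in models for p in m.parameters()]
optimizer = torch.optim.SGD(params, lr=0.1)  # one optimizer over all models

x = torch.randn(16, 8)
y = torch.randint(0, 10, (16,))
optimizer.zero_grad()
loss = mutual_learning_loss([m(x) for m in models], y)
loss.backward()   # a single backward call updates every model
optimizer.step()
```

Because all parameters are in one optimizer and the losses are summed into one scalar, a single `backward()` and `step()` updates the four models together.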
Could you please help me understand how to extract the feature maps from M2 and M3 and forward them to M5 during training, so that I can train all five models simultaneously? Models M1 to M4 use the pre-trained ResNet-34 model with the number of classes set to 10. A short demo showing how M5 takes its inputs from M2 and M3 and how they are forwarded would be very helpful.


You could register forward hooks on the penultimate layer of M2 and M3, as described here.
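A minimal sketch of registering such hooks. In a torchvision ResNet-34, the last block before the pooling and FC layers is `layer4`, whose output for 224x224 inputs has shape (N, 512, 7, 7). To keep the example small and self-contained, `TinyResNetLike` is a hypothetical stand-in that only reuses ResNet's attribute names (`layer4`, `avgpool`, `fc`); `save_activation` and the `activations` dict are also illustrative names.

```python
import torch
import torch.nn as nn

class TinyResNetLike(nn.Module):
    """Hypothetical stand-in with ResNet-style attribute names (layer4, avgpool, fc).
    In practice this would be a torchvision resnet34 with fc set to 10 classes."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
        self.layer4 = nn.Sequential(nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU())
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x):
        x = self.layer4(self.stem(x))
        return self.fc(torch.flatten(self.avgpool(x), 1))

activations = {}

def save_activation(name):
    # PyTorch calls forward hooks as hook(module, inputs, output)
    def hook(module, inputs, output):
        activations[name] = output
    return hook

m2, m3 = TinyResNetLike(), TinyResNetLike()
h2 = m2.layer4.register_forward_hook(save_activation("m2"))
h3 = m3.layer4.register_forward_hook(save_activation("m3"))

x = torch.randn(4, 3, 32, 32)
out2, out3 = m2(x), m3(x)           # the hooks fire during these calls
print(activations["m2"].shape)      # torch.Size([4, 32, 16, 16])
```

`register_forward_hook` returns a handle; calling `h2.remove()` detaches the hook once it is no longer needed.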

During the forward pass of these models, the hooks will be triggered and their outputs can be stored, e.g. in a dict. Once this is done, you can pass the stored activations from the dict to M5 and continue the training.
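Putting it together, one training step for all five models might look like the sketch below. Assumptions: `TinyResNetLike` is a hypothetical small stand-in for ResNet-34 (with a real ResNet-34, the `layer4` maps have 512 channels, so M5 would take 512 + 512 = 1024 input channels); `FusionNet` is a hypothetical M5 with the layers listed in the question (one conv, BatchNorm, a hidden FC layer, an output layer); the two feature maps are concatenated along the channel dimension, which is one choice among several; and plain cross-entropy is used where the mutual-learning terms would go.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyResNetLike(nn.Module):
    # Hypothetical stand-in for ResNet-34 (same attribute names: layer4, avgpool, fc)
    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
        self.layer4 = nn.Sequential(nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU())
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x):
        x = self.layer4(self.stem(x))
        return self.fc(torch.flatten(self.avgpool(x), 1))

class FusionNet(nn.Module):
    # Hypothetical M5: conv -> BatchNorm -> pool -> hidden FC -> output layer
    def __init__(self, in_channels, num_classes=10, hidden=64):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(64)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, hidden)
        self.out = nn.Linear(hidden, num_classes)

    def forward(self, feat_a, feat_b):
        x = torch.cat([feat_a, feat_b], dim=1)  # stack the two feature maps along channels
        x = F.relu(self.bn(self.conv(x)))
        x = torch.flatten(self.pool(x), 1)
        return self.out(F.relu(self.fc(x)))

activations = {}

def save_activation(name):
    def hook(module, inputs, output):
        activations[name] = output  # not detached: M5's loss also reaches M2/M3
    return hook

m1, m2, m3, m4 = (TinyResNetLike() for _ in range(4))
m2.layer4.register_forward_hook(save_activation("m2"))
m3.layer4.register_forward_hook(save_activation("m3"))
m5 = FusionNet(in_channels=32 + 32)

models = [m1, m2, m3, m4, m5]
optimizer = torch.optim.SGD([p for m in models for p in m.parameters()], lr=0.01)

x = torch.randn(8, 3, 32, 32)
y = torch.randint(0, 10, (8,))

optimizer.zero_grad()
logits = [m(x) for m in (m1, m2, m3, m4)]   # the hooks fill `activations` here
logits.append(m5(activations["m2"], activations["m3"]))
loss = sum(F.cross_entropy(out, y) for out in logits)  # replace with your mutual-learning loss
loss.backward()
optimizer.step()
```

Since M2 and M3 must run before M5 in every iteration, the dict is simply overwritten each step, so no manual clearing is needed.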

Depending on whether you want M5's loss to also produce gradients for M2 and M3, you could either store the intermediate activations directly or detach() them in the forward hook.
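The difference between the two options can be checked with a small sketch. The helpers `make_hook` and `grad_reaches_backbone` are hypothetical, and two tiny `nn.Sequential` modules stand in for M2's feature extractor and for M5.

```python
import torch
import torch.nn as nn

def make_hook(store, name, detach):
    def hook(module, inputs, output):
        # detach=True cuts the graph here, so M5's loss cannot update the backbone
        store[name] = output.detach() if detach else output
    return hook

def grad_reaches_backbone(detach):
    torch.manual_seed(0)
    backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())  # stand-in for M2
    head = nn.Sequential(nn.Conv2d(8, 4, 1), nn.AdaptiveAvgPool2d(1),
                         nn.Flatten(), nn.Linear(4, 10))                 # stand-in for M5
    store = {}
    backbone.register_forward_hook(make_hook(store, "m2", detach))
    backbone(torch.randn(2, 3, 8, 8))
    loss = head(store["m2"]).sum()
    loss.backward()
    return backbone[0].weight.grad is not None

print(grad_reaches_backbone(detach=False))  # True
print(grad_reaches_backbone(detach=True))   # False
```

With `detach()`, M5 is trained on the features of M2 and M3 but does not change them; without it, M5's loss also shapes the features of M2 and M3.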

Thanks a lot, ptrblck!