# Access intermediate tensors of a computation

Hi, I’m exploring distillation and similar concepts on CIFAR-10 with a PyTorch implementation of wide-resnet!
I’m looking at this repo in particular: https://github.com/meliketoy/wide-resnet.pytorch
and the wide-resnet file is: https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py

I need to get hold of intermediate results, like the `out` values here:

```python
def forward(self, x):
    out = self.conv1(x)
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = F.relu(self.bn1(out))
    out = F.avg_pool2d(out, 8)
    out = out.view(out.size(0), -1)
    out = self.linear(out)

    return out
```

I want to force the `out` at a certain layer to be similar to the corresponding `out` produced by a teacher network. So my loss looks like cross_entropy_loss + ||out1 - out2|| (the cross-entropy loss of the student network, plus a penalty term that pushes the student's activation toward the teacher's).

Is there a way for me to accomplish this by writing my custom loss function?

Yes, you can return the particular `out` variables you need alongside the final output, and build your loss from them.

For example:

```python
def forward(self, x):
    out1 = self.conv1(x)
    out2 = self.layer1(out1)
    out3 = self.layer2(out2)
    out = self.layer3(out3)
    out = F.relu(self.bn1(out))
    out = F.avg_pool2d(out, 8)
    out = out.view(out.size(0), -1)
    out = self.linear(out)

    # return the intermediate activation along with the logits
    return out, out3
```

Then in your training loop, compare `out3` against the activation from the same layer of the teacher (the two tensors must have the same shape, so pick the same layer in both networks):

```python
out, out3 = model(input)
with torch.no_grad():  # the teacher is frozen
    _, teacher_out3 = teacher(input)

norm = (out3 - teacher_out3).norm(2)
cross_entropy = criterion(out, target)
(cross_entropy + norm).backward()
```
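If you'd rather not edit the repo's `forward` at all, another option is `register_forward_hook`, which captures a module's output from the outside. Here is a minimal, self-contained sketch: the two `nn.Sequential` toy models stand in for the student and teacher Wide-ResNets (an assumption for illustration; in practice you would hook e.g. `model.layer2` on both networks).

```python
import torch
import torch.nn as nn

# Toy stand-ins for the student and teacher networks (assumption: in the real
# setup these would be Wide_ResNet instances with matching layer names).
student = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
teacher = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

feats = {}

def save_to(key):
    # Forward hook: stores the module's output under `key` on every forward pass.
    def hook(module, inputs, output):
        feats[key] = output
    return hook

# Hook the layer whose activations should match (here: the first Linear).
student[0].register_forward_hook(save_to("student"))
teacher[0].register_forward_hook(save_to("teacher"))

x = torch.randn(2, 4)
target = torch.tensor([0, 1])

out = student(x)
with torch.no_grad():   # teacher is frozen; no gradients on its side
    teacher(x)

loss = nn.functional.cross_entropy(out, target) \
     + (feats["student"] - feats["teacher"]).norm(2)
loss.backward()         # gradients flow only into the student
```

This keeps the distillation logic entirely in the training script, which is handy when you don't want to fork the model code.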