Access intermediate tensors of a computation

Hi, I’m exploring distillation and similar concepts on CIFAR-10 with a PyTorch implementation of wide-resnet!
I’m looking at this repo in particular: https://github.com/meliketoy/wide-resnet.pytorch
and the wide-resnet file is: https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py

I need to get hold of intermediate results, such as the “out” tensors here:

```python
def forward(self, x):
    out = self.conv1(x)
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = F.relu(self.bn1(out))
    out = F.avg_pool2d(out, 8)
    out = out.view(out.size(0), -1)
    out = self.linear(out)

    return out
```

I want to force the “out” at a certain layer to be similar to the “out” produced at the same layer by a teacher network. So my loss looks like Cross_entropy_loss + ||out1 - out2||: the cross-entropy loss of the student network, plus a penalty term that pushes its intermediate output toward the teacher network's.
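The loss described above can be written as a small helper. This is a minimal sketch, assuming the student and teacher features at the chosen layer have the same shape; `distillation_loss` and its `weight` parameter are hypothetical names (with `weight=1.0` it is exactly the plain sum from the question). The teacher features are detached so that gradients only reach the student.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, target, student_feat, teacher_feat, weight=1.0):
    """Cross-entropy on the student's logits plus an L2 penalty that pulls the
    student's intermediate features toward the teacher's.

    `weight` is a hypothetical knob balancing the two terms; 1.0 gives the plain sum.
    """
    ce = F.cross_entropy(student_logits, target)
    # .norm(2) with no dim flattens the tensor, i.e. the L2 norm over the whole batch.
    # Detaching the teacher features keeps gradients from flowing into the teacher.
    penalty = (student_feat - teacher_feat.detach()).norm(2)
    return ce + weight * penalty


# Usage with random stand-in tensors (batch of 4, 10 classes, 16x8x8 feature maps).
logits = torch.randn(4, 10, requires_grad=True)
target = torch.randint(0, 10, (4,))
s_feat = torch.randn(4, 16, 8, 8, requires_grad=True)
t_feat = torch.randn(4, 16, 8, 8, requires_grad=True)  # requires_grad only to show detach works

loss = distillation_loss(logits, target, s_feat, t_feat)
loss.backward()  # fills logits.grad and s_feat.grad; t_feat.grad stays None
```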

Is there a way for me to accomplish this by writing my custom loss function?

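Yes. Keep the intermediate “out” tensors you care about in their own variables and return them from forward together with the final output; you can then combine them however you like in the loss.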

For example:

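```python
def forward(self, x):
    out1 = self.conv1(x)
    out2 = self.layer1(out1)
    out3 = self.layer2(out2)
    out = self.layer3(out3)
    out = F.relu(self.bn1(out))
    out = F.avg_pool2d(out, 8)
    out = out.view(out.size(0), -1)
    out = self.linear(out)

    # Return the logits plus the intermediate activation you want to constrain.
    # (out2 and out3 have different shapes in a Wide ResNet, so compare
    # the same layer across student and teacher instead.)
    return out, out3

# Student and teacher both use this forward; their out3 tensors must have the same shape
out, feat = model(input)
with torch.no_grad():  # the teacher is fixed, so don't track gradients through it
    _, teacher_feat = teacher(input)

cross_entropy = criterion(out, target)
norm = (feat - teacher_feat).norm(2)
(cross_entropy + norm).backward()
```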
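If you would rather not change forward at all (for example, to leave the repo's network file untouched), PyTorch's `nn.Module.register_forward_hook` is an alternative way to capture intermediate outputs. This is a minimal sketch, assuming a small hypothetical `TinyNet` in place of the Wide ResNet so the snippet runs on its own, and hooking `layer1` as the layer to match; `capture` and `feats` are illustrative names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical stand-in for the Wide ResNet: conv1 -> layer1 -> layer2 -> linear."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.layer1 = nn.Conv2d(8, 16, 3, padding=1)
        self.layer2 = nn.Conv2d(16, 16, 3, padding=1)
        self.linear = nn.Linear(16, num_classes)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.layer1(out))
        out = F.relu(self.layer2(out))
        out = F.adaptive_avg_pool2d(out, 1).flatten(1)
        return self.linear(out)


def capture(module, store, key):
    """Register a forward hook that saves `module`'s output in store[key]."""
    def hook(mod, inputs, output):
        store[key] = output
    return module.register_forward_hook(hook)  # returns a removable handle


student, teacher = TinyNet(), TinyNet()
teacher.eval()

feats = {}
# Hooks see layer1's raw output, i.e. before the ReLU applied in forward.
h_s = capture(student.layer1, feats, "student")
h_t = capture(teacher.layer1, feats, "teacher")

x = torch.randn(4, 3, 32, 32)
target = torch.randint(0, 10, (4,))

logits = student(x)             # fills feats["student"]
with torch.no_grad():
    teacher(x)                  # fills feats["teacher"] without building a graph

loss = F.cross_entropy(logits, target) + (feats["student"] - feats["teacher"]).norm(2)
loss.backward()

h_s.remove()
h_t.remove()
```

Because the teacher runs under `torch.no_grad()`, its captured features carry no gradient, so `backward()` only updates the student. Removing the handles detaches the hooks once they are no longer needed.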