I am having difficulty extracting the outputs of intermediate layers from a student and a (fixed) teacher model during distillation, and then taking the MSE of these as the training loss. For a start, PyTorch makes it awkward to extract intermediate layer outputs at all: you either have to edit the model's forward function or use a forward hook. I found that the latter method (hook) did not work for me, because the output of the layer is not a leaf tensor and cannot be made to require grad. I resorted to using the package torch_intermediate_layer_getter, which returns an OrderedDict of tensors from the layers you specify. Importantly, these come with grad:
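(For reference, the hook-based approach I mean is roughly the following; this is a stripped-down sketch with a toy model and hypothetical names, not my actual code:)

```python
import torch
import torch.nn as nn

# Toy model standing in for the real student/teacher
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

captured = {}

def save_output(name):
    # Returns a hook that stores the layer's output under `name`
    def hook(module, inputs, output):
        captured[name] = output  # non-leaf tensor produced inside forward()
    return hook

handle = model[0].register_forward_hook(save_output("fc1"))
out = model(torch.randn(2, 8))
handle.remove()

print(captured["fc1"].shape)  # torch.Size([2, 16])
```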
```python
def feature_extractor(model, inputs, return_layers, grad=True):
    """Extract feature maps from model.

    Args:
        model: torch model, model to extract feature maps from
        inputs: torch tensor, input to model
        return_layers: dictionary of layer names to return
        grad: bool, whether the returned features should carry grad
    """
    assert inputs.requires_grad
    mid_getter = MidGet(model, return_layers, True)
    mid_outputs, model_outputs = mid_getter(inputs)
    features = list(mid_outputs.items())[0][1]
    if not grad:
        return features.detach()
    assert features.requires_grad
    return features
```
When computing the rest of my loss function, I always check that variables require grad:
```python
def feature_map_diff(s_map, t_map, aggregate_chan):
    """Compute the difference between the feature maps of the student and teacher models.

    Args:
        s_map: torch tensor, feature map of the student model [batch_size, num_channels]
        t_map: torch tensor, feature map of the teacher model [batch_size, num_channels]
        aggregate_chan: bool, whether to aggregate the channels of the feature activation
    """
    # Aggregate the channels using the root of the summed squared absolute values
    # to create an activation map
    if aggregate_chan:
        s_map = torch.sqrt(torch.sum(torch.abs(s_map) ** 2, dim=1))
        t_map = torch.sqrt(torch.sum(torch.abs(t_map) ** 2, dim=1))
    assert s_map.requires_grad
    # Compute the MSE between the L2-normalised feature maps
    # loss = F.mse_loss(s_map, t_map, reduction='mean').requires_grad_()
    loss = torch.mean((s_map / torch.norm(s_map, p=2, dim=-1).unsqueeze(-1)
                       - t_map / torch.norm(t_map, p=2, dim=-1).unsqueeze(-1)) ** 2)
    assert loss.requires_grad
    return loss
```
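(For context, here is a stripped-down version of the gradient-flow behaviour I am expecting, with toy linear models standing in for my real student/teacher; the teacher is frozen, the loss mirrors my normalised-MSE formula, and the student should receive gradients:)

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real models
student = nn.Linear(8, 16)
teacher = nn.Linear(8, 16)
for p in teacher.parameters():
    p.requires_grad_(False)  # teacher is fixed

x = torch.randn(4, 8)
s_map, t_map = student(x), teacher(x)

# L2-normalise along the last dim, then take the MSE
s_norm = s_map / torch.norm(s_map, p=2, dim=-1, keepdim=True)
t_norm = t_map / torch.norm(t_map, p=2, dim=-1, keepdim=True)
loss = torch.mean((s_norm - t_norm) ** 2)

loss.backward()
# Gradients should reach every student parameter
assert all(p.grad is not None for p in student.parameters())
```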
None of these assertions raises an error, but in my main training loop, with the following check after `loss.backward()` and `optimizer.step()`:
```python
for param in student.parameters():
    assert param.requires_grad
```
I end up with an error. Can someone explain what is going on, and how I can better achieve what I'm trying to do?