Difference in attention maps as loss function

I am having difficulty extracting the outputs of intermediate layers of a student model and a (fixed) teacher model during distillation, and then using the MSE between them as the training loss. To begin with, PyTorch does not make it easy to extract intermediate layer outputs: you either have to edit the model's forward function or register a forward hook. I found that the hook approach did not work, because the output of the layer is not a leaf tensor and cannot require grad. I resorted to the package torch_intermediate_layer_getter, which returns an OrderedDict containing the outputs of the layers you specify. Importantly, these tensors require grad:

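import torch
from torch_intermediate_layer_getter import IntermediateLayerGetter as MidGet

def feature_extractor(model, inputs, return_layers, grad=True):
    """Extract feature maps from a model.
    Args:
        model: torch model to extract feature maps from
        inputs: torch tensor, input to the model
        return_layers: dictionary of layer names to return
        grad: bool, if False the returned features are detached from the graph
    """
    assert inputs.requires_grad
    # The third argument (keep_output=True) also returns the model's final output
    mid_getter = MidGet(model, return_layers, True)
    mid_outputs, model_outputs = mid_getter(inputs)
    # Take the feature map of the first requested layer
    features = next(iter(mid_outputs.values()))
    if not grad:
        return features.detach()
    assert features.requires_grad
    return features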
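For comparison, here is a minimal hook-based sketch with a similar interface, written against plain PyTorch only. The function name feature_extractor_hook, the toy models and the layer name "1" are hypothetical. The hook stores the layer output without detaching it, so the stored tensor is not a leaf but still has requires_grad=True. For the teacher, it runs the forward pass under torch.set_grad_enabled(False) instead of detaching afterwards.

```python
import torch
import torch.nn as nn


def feature_extractor_hook(model, inputs, layer_name, grad=True):
    """Hypothetical hook-based alternative to feature_extractor.

    layer_name: a key from dict(model.named_modules()), e.g. "1".
    grad: if False (e.g. for the frozen teacher), no autograd graph is built.
    """
    layer = dict(model.named_modules())[layer_name]
    store = {}

    def hook(module, inp, out):
        # Keep the output as-is (no .detach()) so it stays in the graph.
        store["out"] = out

    handle = layer.register_forward_hook(hook)
    try:
        with torch.set_grad_enabled(grad):
            model(inputs)
    finally:
        handle.remove()  # remove the hook even if the forward pass fails
    return store["out"]


# Hypothetical toy student and teacher with the same architecture.
student = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
teacher = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
x = torch.randn(5, 4)  # the input itself does not need requires_grad

s_feat = feature_extractor_hook(student, x, "1", grad=True)
t_feat = feature_extractor_hook(teacher, x, "1", grad=False)
print(s_feat.is_leaf, s_feat.requires_grad)  # False True
print(t_feat.requires_grad)                  # False

s_feat.sum().backward()
print(student[0].weight.grad is not None)    # True
```

Only the layers before the hooked one receive gradients here; student[2] is not involved in s_feat, so its gradient stays None.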

In the rest of the loss computation, I check at each step that the tensors require grad:

import torch

def feature_map_diff(s_map, t_map, aggregate_chan):
    """Compute the difference between the feature maps of the student and teacher models.
    Args:
        s_map: torch tensor, activation map of the student model [batch_size, num_channels, ...]
        t_map: torch tensor, activation map of the teacher model [batch_size, num_channels, ...]
        aggregate_chan: bool, whether to aggregate the channels of the feature activation
    """
    # Aggregate the channels into an activation map: square root of the sum of squared absolute values over channels
    if aggregate_chan:
        s_map = torch.sqrt(torch.sum(torch.abs(s_map)**2, dim=1))
        t_map = torch.sqrt(torch.sum(torch.abs(t_map)**2, dim=1))
    assert s_map.requires_grad
    # Compute the difference between the feature maps
    # loss = F.mse_loss(s_map, t_map, reduction='mean').requires_grad_()
    # L2-normalise each map along its last dimension, then take the mean squared difference
    s_norm = s_map / torch.norm(s_map, p=2, dim=-1, keepdim=True)
    t_norm = t_map / torch.norm(t_map, p=2, dim=-1, keepdim=True)
    loss = torch.mean((s_norm - t_norm)**2)
    assert loss.requires_grad
    return loss
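As a point of comparison, here is a sketch of a variant of this loss, assuming 4-D [B, C, H, W] feature maps. With such inputs, normalising with dim=-1 after aggregating channels normalises each row of the map along its width rather than the map as a whole. The version below flattens the spatial dimensions first, similar in spirit to attention-transfer losses. The names attention_map and attention_loss and the eps value are hypothetical.

```python
import torch
import torch.nn.functional as F


def attention_map(feat, eps=1e-12):
    """[B, C, H, W] feature map -> L2-normalised attention map of shape [B, H*W]."""
    # Same aggregation as aggregate_chan: sqrt of the sum of squares over channels.
    # eps keeps the gradient of the sqrt finite where all channels are exactly zero.
    a = (feat.pow(2).sum(dim=1) + eps).sqrt()
    # Flatten H and W so the normalisation covers the whole map,
    # rather than each row along the last (width) axis.
    return F.normalize(a.flatten(start_dim=1), p=2, dim=1)


def attention_loss(s_feat, t_feat):
    """Mean squared difference between the normalised student and teacher maps."""
    return (attention_map(s_feat) - attention_map(t_feat)).pow(2).mean()


# Channel counts may differ (they are summed out); spatial sizes must match.
s = torch.randn(2, 16, 7, 7, requires_grad=True)  # student features
t = torch.randn(2, 32, 7, 7)                       # teacher features, no grad
loss = attention_loss(s, t)
loss.backward()
print(loss.requires_grad, s.grad.shape)  # True torch.Size([2, 16, 7, 7])
```

Because the channel dimension is summed out, this form also works when the student layer is narrower than the teacher layer, which is common in distillation.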

None of these assertions fail. However, in my main training loop I run the following check after loss.backward() and optimizer.step():

for param in student.parameters():
    assert param.requires_grad

I end up with an error. Can someone explain what is going on, and how I could achieve what I'm trying to do in a better way?
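For reference, here is a sketch of a more detailed version of that check, which separates two different conditions. requires_grad is a flag on each parameter that neither loss.backward() nor optimizer.step() changes, whereas param.grad stays None for parameters the loss does not depend on. The toy student and teacher, the hooked layer "1" and the learning rate are hypothetical. In this setup the final student layer does not contribute to the feature loss, so its parameters keep requires_grad=True but receive no gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical toy models; the loss compares the outputs of layer "1" (ReLU).
student = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
teacher = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
for p in teacher.parameters():
    p.requires_grad_(False)  # freeze the teacher only, never the student

feats = {}


def make_hook(key):
    def hook(module, inp, out):
        feats[key] = out  # stored without .detach()
    return hook


student[1].register_forward_hook(make_hook("student"))
teacher[1].register_forward_hook(make_hook("teacher"))
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

x = torch.randn(16, 4)
optimizer.zero_grad()
student(x)
with torch.no_grad():
    teacher(x)
loss = F.mse_loss(feats["student"], feats["teacher"])
loss.backward()
optimizer.step()

# requires_grad is unchanged by backward()/step(); grad is None if unused.
frozen = [n for n, p in student.named_parameters() if not p.requires_grad]
without_grad = [n for n, p in student.named_parameters() if p.grad is None]
print("frozen:", frozen)              # frozen: []
print("no gradient:", without_grad)   # no gradient: ['2.weight', '2.bias']
```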