GradCAM while training break gradients flow

Hi,

I’m trying to manipulate the explainations of a network, using a custom loss that minimize the distance between the Class Activation Maps (CAMs) and a given grayscale image (in this case, a simple matrix of ones for testing). Substancially the idea is to manipulate the explainability maps of the network, as shown in this medium link.

The code I’m using for generating the CAMs is the following:

def get_extractor(model, cam_name, target_layer):

if cam_name != "GradCAM":
    raise ValueError("Al momento supportiamo solo GradCAM.")

layer = eval(f"model.{target_layer}")

extractor = {
    'features': None,
    'gradients': None,
    'handles': []  # Lista per salvare gli hook handles
}

def forward_hook(module, input, output):
    extractor['features'] = output  # Salva le feature map
    return None

def backward_hook(module, grad_in, grad_out):
    extractor['gradients'] = grad_out[0]  # Salva i gradienti
    return None

extractor['handles'].append(layer.register_forward_hook(forward_hook))
extractor['handles'].append(layer.register_backward_hook(backward_hook))

def remove_hooks():
    for handle in extractor['handles']:
        handle.remove()
    extractor['handles'] = []  # Svuota la lista degli handles

extractor['remove_hooks'] = remove_hooks  # Aggiungi il metodo al dizionario

return extractor

def cam_extractor_fn(model, extractor, inputs, verbose=False, dont_normalize=False):

model.eval()
inputs.requires_grad = True  # Traccia i gradienti per gli input

logits = model(inputs)
if verbose:
    print(f"Logits shape: {logits.shape}")

one_hot = torch.zeros_like(logits)
target_indices = logits.argmax(dim=1)
one_hot.scatter_(1, target_indices.unsqueeze(1), 1)

logits.backward(gradient=one_hot, retain_graph=True)

feature_maps = extractor['features']
gradients = extractor['gradients']

if verbose:
    # Debug: Verifica delle feature map e dei gradienti
    print(f"Features mean: {feature_maps.mean().item()}, std: {feature_maps.std().item()}")
    print(f"Gradient mean: {gradients.mean().item()}, std: {gradients.std().item()}")

# Calcolo delle CAM
weights = gradients.mean(dim=(2, 3), keepdim=True)  
cam = (weights * feature_maps).sum(dim=1)

if not dont_normalize:
    cam_min = cam.amin(dim=(1, 2), keepdim=True)
    cam = cam - cam_min
    cam_max = cam.amax(dim=(1, 2), keepdim=True) + 1e-5
    cam = cam / cam_max

if verbose:
    print(f"CAM min: {cam.min().item()}, max: {cam.max().item()}")
return cam

And if tested singularly it works correctly giving the right outputs. The main training cycle follows this logic:

from cam_for_dist import get_extractor, cam_extractor_fn
cam_name = "GradCAM"
target_layer = "model.layer4"
assert net.model.layer4 is not None, "The model must have a layer4 attribute"
extractor = get_extractor(net, cam_name, target_layer)
mse_loss = nn.MSELoss()


for epoch in range(epochs):  
    correct_top1 = 0
    running_loss = 0.0  
    correct_top1_val = 0
    running_loss_val = 0.0

    net.train()
    idx = 0
    for inputs, labels in trainloader:

        inputs, labels = inputs.to(device), labels.to(device)
        inputs.requires_grad = True
        net.to(device)
        optimizer.zero_grad()


        outputs = net(inputs)
        loss = criterion(outputs, labels)
        

        net.eval()  
        inputs.requires_grad = True
        cam = cam_extractor_fn(net, extractor, inputs, verbose=False, dont_normalize = False)
        inputs.requires_grad = False
        net.train()  

        cam_target = torch.ones_like(cam)  
        cam_loss = mse_loss(cam, cam_target)

        loss =  loss + loss_cam_weight * cam_loss
        _, predicted = torch.max(outputs, 1)
        correct_top1 += (predicted == labels).sum().item()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

but the issue is that evaluating

        cam = cam_extractor_fn(net, extractor, inputs, verbose=False, dont_normalize = False)

during the training breaks the gradients flow and the convergence. Even if the cam_loss is removed from the loss calculation. The training without the cam_extractor_fn works as expected and reaches convergence. Any ideas on what to do? Any help is kindly welcome.