Hi,
I’m trying to manipulate the explainations of a network, using a custom loss that minimize the distance between the Class Activation Maps (CAMs) and a given grayscale image (in this case, a simple matrix of ones for testing). Substancially the idea is to manipulate the explainability maps of the network, as shown in this medium link.
The code I’m using for generating the CAMs is the following:
def get_extractor(model, cam_name, target_layer):
if cam_name != "GradCAM": raise ValueError("Al momento supportiamo solo GradCAM.") layer = eval(f"model.{target_layer}") extractor = { 'features': None, 'gradients': None, 'handles': [] # Lista per salvare gli hook handles } def forward_hook(module, input, output): extractor['features'] = output # Salva le feature map return None def backward_hook(module, grad_in, grad_out): extractor['gradients'] = grad_out[0] # Salva i gradienti return None extractor['handles'].append(layer.register_forward_hook(forward_hook)) extractor['handles'].append(layer.register_backward_hook(backward_hook)) def remove_hooks(): for handle in extractor['handles']: handle.remove() extractor['handles'] = [] # Svuota la lista degli handles extractor['remove_hooks'] = remove_hooks # Aggiungi il metodo al dizionario return extractor
def cam_extractor_fn(model, extractor, inputs, verbose=False, dont_normalize=False):
model.eval() inputs.requires_grad = True # Traccia i gradienti per gli input logits = model(inputs) if verbose: print(f"Logits shape: {logits.shape}") one_hot = torch.zeros_like(logits) target_indices = logits.argmax(dim=1) one_hot.scatter_(1, target_indices.unsqueeze(1), 1) logits.backward(gradient=one_hot, retain_graph=True) feature_maps = extractor['features'] gradients = extractor['gradients'] if verbose: # Debug: Verifica delle feature map e dei gradienti print(f"Features mean: {feature_maps.mean().item()}, std: {feature_maps.std().item()}") print(f"Gradient mean: {gradients.mean().item()}, std: {gradients.std().item()}") # Calcolo delle CAM weights = gradients.mean(dim=(2, 3), keepdim=True) cam = (weights * feature_maps).sum(dim=1) if not dont_normalize: cam_min = cam.amin(dim=(1, 2), keepdim=True) cam = cam - cam_min cam_max = cam.amax(dim=(1, 2), keepdim=True) + 1e-5 cam = cam / cam_max if verbose: print(f"CAM min: {cam.min().item()}, max: {cam.max().item()}") return cam
And if tested singularly it works correctly giving the right outputs. The main training cycle follows this logic:
from cam_for_dist import get_extractor, cam_extractor_fn cam_name = "GradCAM" target_layer = "model.layer4" assert net.model.layer4 is not None, "The model must have a layer4 attribute" extractor = get_extractor(net, cam_name, target_layer) mse_loss = nn.MSELoss() for epoch in range(epochs): correct_top1 = 0 running_loss = 0.0 correct_top1_val = 0 running_loss_val = 0.0 net.train() idx = 0 for inputs, labels in trainloader: inputs, labels = inputs.to(device), labels.to(device) inputs.requires_grad = True net.to(device) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) net.eval() inputs.requires_grad = True cam = cam_extractor_fn(net, extractor, inputs, verbose=False, dont_normalize = False) inputs.requires_grad = False net.train() cam_target = torch.ones_like(cam) cam_loss = mse_loss(cam, cam_target) loss = loss + loss_cam_weight * cam_loss
_, predicted = torch.max(outputs, 1) correct_top1 += (predicted == labels).sum().item() loss.backward() optimizer.step() running_loss += loss.item()
but the issue is that evaluating
cam = cam_extractor_fn(net, extractor, inputs, verbose=False, dont_normalize = False)
during the training breaks the gradients flow and the convergence. Even if the cam_loss is removed from the loss calculation. The training without the cam_extractor_fn works as expected and reaches convergence. Any ideas on what to do? Any help is kindly welcome.