Hi, I am using the following code for calculating the gradcam:
class GradCAM(BaseCAM):
def __init__(self, model, target_layer="module.layer4.2"):
super().__init__(model, target_layer)
def forward(self, x, class_idx=None, retain_graph=False):
b, c, h, w = x.size()
# predication on raw x
logit = self.model(x)
if class_idx is None:
score = logit[:, logit.max(1)[-1]].squeeze()
else:
score = logit[:, class_idx].squeeze()
self.model.zero_grad()
score.backward(retain_graph=retain_graph)
gradients = self.gradients['value'].data
activations = self.activations['value'].data
b, k, u, v = activations.size()
alpha = gradients.view(b, k, -1).mean(2)
weights = alpha.view(b, k, 1, 1)
saliency_map = (weights * activations).sum(1, keepdim=True)
saliency_map = F.relu(saliency_map)
saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)
saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
saliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)
return saliency_map
def __call__(self, x, class_idx=None, retain_graph=False):
return self.forward(x, class_idx, retain_graph)
The aforementioned code only works when the batch size b
is equal to one. I have modified the code to support the case where b >=1
. I have used a for loop for calculating each input’s corresponding layer gradient separately. This code works fine, but it becomes slower when the batch size increase. How can I solve this problem?
class GradCAM(BaseCAM):
def __init__(self, model, target_layer="module.layer4.2"):
super().__init__(model, target_layer)
def forward(self, x, class_idx=None, retain_graph=False):
b, c, h, w = x.size()
# predication on raw x
logit = self.model(x)
softmax = F.softmax(logit, dim=1)
if class_idx is None:
score = logit[:, logit.max(1)[-1]]
else:
score = logit[:, class_idx]
if b > 1:
retain_graph = True
self.model.zero_grad()
gradients_list = []
for i, item in enumerate(score):
item.backward(retain_graph=retain_graph)
gradients = self.gradients['value'].data[i]
gradients_list.append(gradients)
gradients = torch.stack(gradients_list, dim=0)
activations = self.activations['value'].data
b, k, u, v = activations.size()
alpha = gradients.view(b, k, -1).mean(2)
weights = alpha.view(b, k, 1, 1)
saliency_map = (weights * activations).sum(1, keepdim=True)
saliency_map = F.relu(saliency_map)
saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)
saliency_map_shape = saliency_map.shape
saliency_map = saliency_map.view(saliency_map.shape[0], -1)
saliency_map_min, saliency_map_max = saliency_map.min(1, keepdim=True)[0], saliency_map.max(1, keepdim=True)[0]
saliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)
saliency_map = saliency_map.view(saliency_map_shape)
return saliency_map, softmax.detach()