How to calucate GradCam efficiently when batch size is greater than one

Hi, I am using the following code for calculating the gradcam:

class GradCAM(BaseCAM):
    def __init__(self, model, target_layer="module.layer4.2"):
        super().__init__(model, target_layer)

    def forward(self, x, class_idx=None, retain_graph=False):
        b, c, h, w = x.size()

        # predication on raw x
        logit = self.model(x)

        if class_idx is None:
            score = logit[:, logit.max(1)[-1]].squeeze()
        else:
            score = logit[:, class_idx].squeeze()

        self.model.zero_grad()
        score.backward(retain_graph=retain_graph)
        gradients = self.gradients['value'].data
        activations = self.activations['value'].data
        b, k, u, v = activations.size()

        alpha = gradients.view(b, k, -1).mean(2)
        weights = alpha.view(b, k, 1, 1)
        saliency_map = (weights * activations).sum(1, keepdim=True)

        saliency_map = F.relu(saliency_map)
        saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)
        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        saliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)
        return saliency_map

    def __call__(self, x, class_idx=None, retain_graph=False):
        return self.forward(x, class_idx, retain_graph)

The aforementioned code only works when the batch size b is equal to one. I have modified the code to support the case where b >=1. I have used a for loop for calculating each input’s corresponding layer gradient separately. This code works fine, but it becomes slower when the batch size increase. How can I solve this problem?

class GradCAM(BaseCAM):
    def __init__(self, model, target_layer="module.layer4.2"):
        super().__init__(model, target_layer)

    def forward(self, x, class_idx=None, retain_graph=False):
        b, c, h, w = x.size()

        # predication on raw x
        logit = self.model(x)
        softmax = F.softmax(logit, dim=1)

        if class_idx is None:
            score = logit[:, logit.max(1)[-1]]
        else:
            score = logit[:, class_idx]

        if b > 1:
            retain_graph = True

        self.model.zero_grad()
        gradients_list = []
        for i, item in enumerate(score):
            item.backward(retain_graph=retain_graph)
            gradients = self.gradients['value'].data[i]
            gradients_list.append(gradients)

        gradients = torch.stack(gradients_list, dim=0)
        activations = self.activations['value'].data
        b, k, u, v = activations.size()

        alpha = gradients.view(b, k, -1).mean(2)
        weights = alpha.view(b, k, 1, 1)
        saliency_map = (weights * activations).sum(1, keepdim=True)

        saliency_map = F.relu(saliency_map)
        saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)

        saliency_map_shape = saliency_map.shape
        saliency_map = saliency_map.view(saliency_map.shape[0], -1)
        saliency_map_min, saliency_map_max = saliency_map.min(1, keepdim=True)[0], saliency_map.max(1, keepdim=True)[0]
        saliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)
        saliency_map = saliency_map.view(saliency_map_shape)
        return saliency_map, softmax.detach()