Error generating Grad-CAM: 'tuple' object has no attribute 'cpu'

Help to solve this error: Error generating Grad-CAM: 'tuple' object has no attribute 'cpu'

Add this class somewhere above your Grad-CAM code:

class GradCAMWrapper(nn.Module):
    """Adapter that makes a model return a single output tensor.

    Grad-CAM needs the forward pass to yield one tensor of class scores.
    The wrapped model may instead return a tuple or an object exposing a
    ``logits`` attribute; this wrapper unwraps both cases.

    Args:
        model: The model to wrap; it is called as ``model(x)``.
        model_name: Identifier of the architecture. It is stored on the
            instance but not consulted by ``forward``.
    """

    def __init__(self, model, model_name):
        # The dunder names matter: a plain ``init`` is never invoked on
        # construction, so nn.Module would stay uninitialised.
        super().__init__()
        self.model = model
        self.model_name = model_name

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its primary output.

        Returns:
            The first element when the model returns a tuple, the
            ``logits`` attribute when the output has one, and the raw
            output otherwise.
        """
        out = self.model(x)
        if isinstance(out, tuple):
            # Tuple outputs: the first element is taken as the scores.
            return out[0]
        if hasattr(out, "logits"):
            return out.logits
        return out

#-----------------------------

def find_first_tensor(x):
    """Return the first ``torch.Tensor`` in a nested tuple/list structure.

    The structure is walked depth-first, left to right, so a tensor that
    is not at index 0 of its container is still found.

    Args:
        x: A tensor, or an arbitrarily nested tuple/list of values.

    Returns:
        The first tensor encountered.

    Raises:
        ValueError: If the structure contains no tensor.
    """
    pending = [x]
    while pending:
        current = pending.pop()
        if isinstance(current, torch.Tensor):
            return current
        if isinstance(current, (tuple, list)):
            # Push reversed so the leftmost child is popped (visited) first.
            pending.extend(reversed(current))
    raise ValueError("No tensor found in Grad-CAM output")

def _vit_tokens_to_feature_map(tokens):
    """Reshape transformer token activations into a CNN-style feature map.

    Passed to GradCAM as ``reshape_transform`` so that both activations and
    gradients reach it as ``(batch, channels, height, width)``.

    Args:
        tokens: A tensor, or a tuple/list whose first item is the tensor.

    Returns:
        The reshaped tensor, or the input tensor unchanged when it is not
        3-dimensional.

    Raises:
        ValueError: If the tokens cannot be arranged on a square grid.
    """
    # Some transformer blocks hand their hidden state back inside a tuple;
    # calling tensor methods on that tuple is what produces
    # "'tuple' object has no attribute 'cpu'".
    if isinstance(tokens, (tuple, list)):
        tokens = tokens[0]
    if tokens.dim() != 3:
        return tokens

    batch, num_tokens, channels = tokens.shape
    # Assumes one leading class token followed by a square patch grid
    # -- TODO confirm this also holds for the 'hybrid' model.
    num_patches = num_tokens - 1
    side = int(round(num_patches ** 0.5))
    if side < 1 or side * side != num_patches:
        raise ValueError(
            f"Cannot arrange {num_patches} patch tokens on a square grid"
        )
    patches = tokens[:, 1:, :].reshape(batch, side, side, channels)
    # Channels first, as for convolutional feature maps.
    return patches.permute(0, 3, 1, 2)


def visualize_gradcam(model, image, device, model_name='vit'):
    """Overlay a Grad-CAM heat map on ``image``.

    Args:
        model: Model exposing ``model.vit.encoder.layer``.
        image: Single image tensor laid out as (channels, height, width);
            it is de-normalised below with the ImageNet mean/std constants.
        device: Device the batched input is moved to.
        model_name: ``'vit'`` or ``'hybrid'``.

    Returns:
        Whatever ``show_cam_on_image`` produces for the overlay, or ``None``
        when ``model_name`` is unsupported or any Grad-CAM step fails (the
        failure is logged).
    """
    # Validate first so an unsupported name has no side effects on the model.
    if model_name not in ('vit', 'hybrid'):
        logging.error(f"Unsupported model_name: {model_name}")
        return None

    model.eval()
    wrapped_model = GradCAMWrapper(model, model_name)

    # Grad-CAM hooks the target layer's raw output, so that output must be a
    # tensor. Prefer the block's pre-attention LayerNorm when it exists and
    # fall back to the whole block otherwise.
    last_block = model.vit.encoder.layer[-1]
    target_layers = [getattr(last_block, "layernorm_before", last_block)]

    try:
        cam = GradCAM(
            model=wrapped_model,
            target_layers=target_layers,
            reshape_transform=_vit_tokens_to_feature_map,
        )
    except Exception as e:
        logging.error(f"Error initializing GradCAM: {e}")
        return None

    input_tensor = image.unsqueeze(0).to(device)

    try:
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)
        logging.info(f"Type of grayscale_cam after cam(): {type(grayscale_cam)}")
        logging.info(f"Value of grayscale_cam after cam(): {grayscale_cam}")
        # A NumPy result is used directly; only other return types go
        # through the tensor search and conversion.
        if not isinstance(grayscale_cam, np.ndarray):
            grayscale_cam = find_first_tensor(grayscale_cam)
            grayscale_cam = grayscale_cam.detach().cpu().numpy()
        # Drop the batch axis: the overlay expects a 2-D mask.
        if grayscale_cam.ndim == 3:
            grayscale_cam = grayscale_cam[0]
    except Exception as e:
        logging.error(f"Error generating Grad-CAM: {e}")
        return None

    # Undo the input normalisation so the image is back in [0, 1].
    img = image.detach().permute(1, 2, 0).cpu().numpy()
    img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img = np.clip(img, 0, 1).astype(np.float32)

    try:
        visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)
        return visualization
    except Exception as e:
        logging.error(f"Error creating Grad-CAM visualization: {e}")
        return None