Help to solve this error: Error generating Grad-CAM: 'tuple' object has no attribute 'cpu'
Add this class somewhere above your Grad-CAM code
class GradCAMWrapper(nn.Module):
    """Adapt a model so that its forward pass returns a single object.

    The wrapped model's output is normalised as follows:

    * a tuple      -> its first element,
    * an object with a ``logits`` attribute -> that attribute,
    * anything else -> returned unchanged.

    Args:
        model: The callable/module to wrap.
        model_name: Label stored on the wrapper; it is not consulted by
            ``forward``.
    """

    def __init__(self, model, model_name):
        # The dunder names matter: a method called plain ``init`` is never run
        # by Python, so nn.Module would be left uninitialised.
        super().__init__()
        self.model = model
        self.model_name = model_name

    def forward(self, x):
        """Run the wrapped model on ``x`` and unwrap its output."""
        out = self.model(x)
        if isinstance(out, tuple):
            # Tuple outputs: take the first element.
            return out[0]
        if hasattr(out, "logits"):
            return out.logits
        return out
#-----------------------------
def find_first_tensor(x):
    """Return the first ``torch.Tensor`` found in a nested tuple/list structure.

    The structure is walked depth-first, left to right, so the result is the
    tensor that appears first when the nesting is flattened. Elements that are
    neither tensors nor tuples/lists are skipped.

    Args:
        x: A tensor, or an arbitrarily nested tuple/list that may contain one.

    Returns:
        The first tensor encountered.

    Raises:
        ValueError: If no tensor is present anywhere in ``x``.
    """
    # Explicit stack instead of recursion: children are pushed in reverse so
    # that the left-most one is popped (and therefore inspected) first.
    pending = [x]
    while pending:
        item = pending.pop()
        if isinstance(item, torch.Tensor):
            return item
        if isinstance(item, (tuple, list)):
            pending.extend(reversed(item))
    raise ValueError("No tensor found in Grad-CAM output")
def _tokens_to_feature_map(tokens):
    """Lay out transformer token activations as a 2-D feature map.

    A 3-D input of shape (batch, tokens, channels) is returned as
    (batch, channels, side, side). Inputs that are not 3-D are returned
    unchanged.

    NOTE(review): assumes a square patch grid, optionally preceded by exactly
    one class token -- confirm against the model actually in use.

    Raises:
        ValueError: If the token count cannot form a square grid.
    """
    if tokens.dim() != 3:
        return tokens
    batch, n_tokens, channels = tokens.shape
    side = int(round((n_tokens - 1) ** 0.5))
    if side > 0 and side * side == n_tokens - 1:
        # One extra leading token (presumably the class token): drop it.
        patches = tokens[:, 1:, :]
    else:
        side = int(round(n_tokens ** 0.5))
        if side * side != n_tokens:
            raise ValueError(
                f"Cannot arrange {n_tokens} tokens into a square grid"
            )
        patches = tokens
    return patches.reshape(batch, side, side, channels).permute(0, 3, 1, 2)


def visualize_gradcam(model, image, device, model_name='vit', *,
                      mean=(0.485, 0.456, 0.406),
                      std=(0.229, 0.224, 0.225)):
    """Overlay a Grad-CAM heatmap on a normalised input image.

    Args:
        model: Model exposing ``model.vit.encoder.layer``.
        image: Channel-first image tensor (it is permuted with ``(1, 2, 0)``),
            normalised with ``mean``/``std``.
        device: Device the batched input is moved to before running Grad-CAM.
        model_name: ``'vit'`` or ``'hybrid'``; anything else is rejected.
        mean: Per-channel mean used to undo the input normalisation.
        std: Per-channel standard deviation used to undo the normalisation.

    Returns:
        The image produced by ``show_cam_on_image``, or ``None`` if the model
        name is unsupported or any Grad-CAM step fails (the error is logged).
    """
    model.eval()
    wrapped_model = GradCAMWrapper(model, model_name)

    if model_name in ('vit', 'hybrid'):
        last_block = model.vit.encoder.layer[-1]
        # Hooking the whole encoder block fails with "'tuple' object has no
        # attribute 'cpu'" when the block returns a tuple, because Grad-CAM
        # calls .cpu() on the hooked output. Hook a sublayer that returns a
        # plain tensor instead.
        # NOTE(review): assumes a Hugging Face-style block that exposes
        # `layernorm_before`; otherwise the block itself is used as before.
        target_layers = [getattr(last_block, 'layernorm_before', last_block)]
    else:
        logging.error(f"Unsupported model_name: {model_name}")
        return None

    try:
        cam = GradCAM(
            model=wrapped_model,
            target_layers=target_layers,
            # Token sequences must be reshaped to (B, C, H, W) for Grad-CAM.
            reshape_transform=_tokens_to_feature_map,
        )
    except Exception as e:
        logging.error(f"Error initializing GradCAM: {e}")
        return None

    input_tensor = image.unsqueeze(0).to(device)
    try:
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)
        logging.info(f"Type of grayscale_cam after cam(): {type(grayscale_cam)}")
        logging.info(f"Value of grayscale_cam after cam(): {grayscale_cam}")
        if not isinstance(grayscale_cam, np.ndarray):
            # Fallback for tensor (or nested tensor) outputs; the usual
            # result is already a NumPy array and has no .cpu()/.detach().
            grayscale_cam = find_first_tensor(grayscale_cam)
            grayscale_cam = grayscale_cam.detach().cpu().numpy()
        if grayscale_cam.ndim == 3:
            # Drop the batch axis: the overlay expects a single 2-D mask.
            grayscale_cam = grayscale_cam[0]
    except Exception as e:
        logging.error(f"Error generating Grad-CAM: {e}")
        return None

    # Undo the normalisation so the overlay is drawn on a [0, 1] RGB image.
    img = image.detach().permute(1, 2, 0).cpu().numpy()
    img = img * np.array(std) + np.array(mean)
    img = np.clip(img, 0, 1).astype(np.float32)

    try:
        visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)
        return visualization
    except Exception as e:
        logging.error(f"Error creating Grad-CAM visualization: {e}")
        return None