RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. could you please solve it

This is my code
"
import warnings
warnings.filterwarnings(‘ignore’)
warnings.simplefilter(‘ignore’)
import torch, yaml, cv2, os, shutil
import torch.nn as nn
import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt
import time
import os.path
from tqdm import trange
from PIL import Image
from models.yolo import Model
from utils.torch_utils import intersect_dicts
from utils.datasets import letterbox
from utils.general import xywh2xyxy
from utils.general import non_max_suppression, scale_coords
import torch

#from pytorch_grad_cam import GradCAMPlusPlus, GradCAM, XGradCAM
from pytorch_grad_cam import (
GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
LayerCAM, FullGrad, GradCAMElementWise
)
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients

class yolov7_heatmap:
def init(self, weight, cfg, device, method, layer, backward_type, conf_threshold, ratio):
device = torch.device(device)
#ckpt = torch.load(weight,map_location=torch.device(‘cpu’))
ckpt = torch.load(weight)
model_names = ckpt[‘model’].names
csd = ckpt[‘model’].float().state_dict() # checkpoint state_dict as FP32
model = Model(cfg, ch=3, nc=len(model_names)).to(device)
csd = intersect_dicts(csd, model.state_dict(), exclude=[‘anchor’]) # intersect
model.load_state_dict(csd, strict=False) # load
model.eval()

    target_layers = [eval(layer)]
    method = eval(method)

    colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(np.int_)
    self.__dict__.update(locals())

def post_process(self, result):
    boxes_ = result[0][..., :4]
    logits_ = []
    #print(logits_)
    for data in result[1]:
        bs, n, w, h, _ = data.size()
        logits_.append(data.reshape((bs, n * w * h, _)))
    logits_ = torch.cat(logits_, dim=1)[..., 4:]
    sorted, indices = torch.sort(logits_[..., 0], descending=True)
    logits_ = logits_[0][indices[0]]
    logits_[:, 0] = torch.sigmoid(logits_[:, 0])
    print("post result size------------ : ",logits_.size())
    print("logits_--------------: " , logits_)
    print("boxes_",boxes_)
    print("boxes_[0][indices[0]]----: ",boxes_[0][indices[0]])
    print("xywh2xyxy(boxes_[0][indices[0]]).cpu().detach().numpy()-------------:",  xywh2xyxy(boxes_[0][indices[0]]).cpu().detach().numpy())
    return logits_, xywh2xyxy(boxes_[0][indices[0]]).cpu().detach().numpy()
    
def post_process_2(self,result):
    """
    Function to extract coordinates (boxes) and confidence scores from the tensor.
    Args:
        tensor (torch.Tensor): Input tensor containing detection data.
    Returns:
        torch.Tensor: Extracted coordinates.
        list of str: Extracted and formatted confidence scores.
    """
    # Extract the coordinates
    boxes_ = result[..., :4]
    # Extract the confidence scores
    confidence = result[..., 4]

    # Format confidence values
    confidence =torch.sigmoid(confidence)
    print("boxes _ : ", boxes_)
    print("xywh2xyxy(boxes_).cpu().detach().numpy() --- :", boxes_.cpu().detach().numpy())
    return confidence, xywh2xyxy(boxes_).cpu().detach().numpy()


def draw_detections(self, box, color, name, img):
    xmin, ymin, xmax, ymax = list(map(int, list(box)))
    f=open("Traffic_2560x1600_30_0118_060.csv","a")
    f.write(name)
    f.write(" : ")
    f.write(str(xmin))
    f.write(",")
    f.write(str(ymin))
    f.write(",")
    f.write(str(xmax))
    f.write(",")
    f.write(str(ymax))
    f.write("\n")
    f.close()
    
    cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)
    cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)
    
    return img
    
def combine_images(self, folder_path, output_path):
    images = []
    
    # Iterate through all files in the folder
    for filename in os.listdir(folder_path):
        if filename.endswith(".jpg") or filename.endswith(".png"):
            img_path = os.path.join(folder_path, filename)
            img = Image.open(img_path)
            images.append(img)
    
    # Get the dimensions of the first image
    width, height = images[0].size
    
    # Create a new image with a height that can accommodate all images
    combined_img = Image.new("RGB", (width, len(images) * height))
    
    # Paste each image into the combined image
    y_offset = 0
    for img in images:
        combined_img.paste(img, (0, y_offset))
        y_offset += height
    
    # Save the combined image
    combined_img.save(f'{output_path}/{filename}.png')
    
def __call__(self, img_path, save_path):
    # remove dir if exist
    if os.path.exists(save_path):
        shutil.rmtree(save_path)
    # make dir if not exist
    os.makedirs(save_path, exist_ok=True)
    

    # img process
    img = cv2.imread(img_path)
    img = letterbox(img)[0]
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = np.float32(img) / 255.0
    tensor = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)
    
    count=1
    #torch.autograd.set_detect_anomaly(True)
    # init ActivationsAndGradients
    print(f'The count is : {count}  start................')
    grads = ActivationsAndGradients(self.model, self.target_layers, reshape_transform=None)
    # get ActivationsAndResult
    result = grads(tensor)
    activations = grads.activations[0].cpu().detach().numpy()
    #print("results for the images after going through the model........................... \n", result )
    
    conf_thres = 0.25
    iou_thres = 0.5 
    predictions = non_max_suppression(result[0], conf_thres, iou_thres)
    
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
     list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """
    
    #print("activations are ..................\n", activations)
    
    print("Here is the result ........................\n")
    print(result)
    print("\n Here is the predictions ---------------------------\n")
    #print(torch.eq(result,predictions))

    # postprocess to yolo output
    post_result, post_boxes = self.post_process(result)
    print("\n result size------------ : ")
    #.......................................................................................................................................
    print(" type of all ------------ : \n")
    def check_tensor_type(tensor):
        if isinstance(tensor, tuple):
            print("This is a tuple-type tensor")
        elif isinstance(tensor, list):
            print("This is a list-type tensor")
        elif isinstance(tensor, torch.Tensor):
            print("This is a regular PyTorch tensor")
        else:
            print("This is not a recognized tensor type")
    check_tensor_type(result)
    check_tensor_type(predictions)
    check_tensor_type(post_boxes)
    check_tensor_type(post_result)
    
    
    post_result, post_boxes = self.post_process_2(predictions[0])
    print("\n post_result[0]------------------: ",post_result[0])
    print("post_result[0].requires_grad----------:",post_result[0].requires_grad)
    self.model.zero_grad()
    if self.backward_type == 'conf':
            #confidence = confidence.type(torch.cuda.FloatTensor)
            post_result[0].backward(retain_graph=True, create_graph=True)
    else:
        # get max probability for this prediction
        class_id.backward(retain_graph=True)
        
    gradients = grads.gradients[0]
    b, k, u, v = gradients.size()
    weights = self.method.get_cam_weights(self.method, None, None, None, activations, gradients.detach().numpy())
    weights = weights.reshape((b, k, 1, 1))
    saliency_map = np.sum(weights * activations, axis=1)
    saliency_map = np.squeeze(np.maximum(saliency_map, 0))
    saliency_map = cv2.resize(saliency_map, (tensor.size(3), tensor.size(2)))
    saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
    #if (saliency_map_max - saliency_map_min) == 0:
     #   continue
    saliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)
    # add heatmap and box to image
    cam_image = show_cam_on_image(img.copy(), saliency_map, use_rgb=True)
    cam_image = self.draw_detections(post_boxes[i], self.colors[int(post_result[i, 1:].argmax())], f'{self.model_names[int(post_result[i, 1:].argmax())]} {post_result[i][0]:.2f}', cam_image)
    total_frames=total_frames+1
    cam_image = Image.fromarray(cam_image)
    #cv2.imwrite("tokkalo_reuslt.png",cam_image)
    cam_image.save(f'{save_path}/{i}.png')
    
#&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&7    

def get_params():
params = {
‘weight’: ‘yolov7.pt’,
‘cfg’: ‘cfg/training/yolov7.yaml’,
‘device’: ‘cuda:0’,
‘method’: ‘GradCAM’, # GradCAMPlusPlus, GradCAM, XGradCAM
‘layer’: ‘model.model[-2]’,
‘backward_type’: ‘conf’, # class or conf
‘conf_threshold’: 0.6, # 0.6
‘ratio’: 0.02 # 0.02-0.1
}
return params

if name == ‘main’:
model = yolov7_heatmap(**get_params())
f=open(“Traffic_2560x1600_30_0118_060.csv”,“w”)
f.write(“name : xmin, ymin, xmax, ymax \n”)
f.close()
model(‘Input/image3.jpg’, ‘output/output/image3_2’)

'''
input_folder='Input/Traffic_2560x1600_30_0118_10frames'
k=0
for image in os.listdir(input_folder):
    print(image)
    #print(image.type)
    image_path=os.path.join(input_folder,image)
    output_path=os.path.join('/mnt/beegfs/home/vdhulipudi2023/venv_example/pytorch-grad-cam/tutorials/yolov7/output/10frames',image)
    print(output_path)
    model(image_path,output_path)
    ''' 

and I am getting this error
“Traceback (most recent call last):
File “/mnt/beegfs/home/vdhulipudi2023/venv_example/pytorch-grad-cam/tutorials/yolov7/z.py”, line 251, in
model(‘Input/image3.jpg’, ‘output/output/image3_2’)
File “/mnt/beegfs/home/vdhulipudi2023/venv_example/pytorch-grad-cam/tutorials/yolov7/z.py”, line 202, in call
post_result[0].backward(retain_graph=True, create_graph=True)
File “/mnt/beegfs/home/vdhulipudi2023/venv_example/new_env/lib/python3.12/site-packages/torch/_tensor.py”, line 525, in backward
torch.autograd.backward(
File “/mnt/beegfs/home/vdhulipudi2023/venv_example/new_env/lib/python3.12/site-packages/torch/autograd/init.py”, line 267, in backward
_engine_run_backward(
File “/mnt/beegfs/home/vdhulipudi2023/venv_example/new_env/lib/python3.12/site-packages/torch/autograd/graph.py”, line 744, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [15, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).” how can i resolve this ? actually after doing NMS I got this error

Would the following work for you? Automatic differentiation package - torch.autograd — PyTorch 2.3 documentation

Hi Vaishnavi!

Soulitzer’s suggestion of using allow_mutation_on_saved_tensors():
may be appropriate for your use case.

But if you don’t need to modify tensors inplace (maybe you’re doing it
inadvertently somewhere), it may makes sense to fix the underlying issue.

To track things down, start with the information in your error message.
You have a tensor of shape [15, 1] that is being modified inplace.
Where in your forward pass do you have such a tensor (or tensors)? Its
._version property is changing from 0 to 2. Try printing out t._version
at various place in your code to see where t._version increases from 0
to 1 (one inplace modification) and then from 1 to 2 (a second inplace
modification).

Does your suspect tensor have ._version = 2 just before your call to
post_result[0].backward(retain_graph=True, create_graph=True)?

(Note, if you find and fix any inplace-modification errors for a given tensor,
errors for additional tensors may show up – autograd aborts the backward
pass by raising the RuntimeError you saw, so it only flags the first error
it encounters.)

You should use anomaly detection – it provides additional information that
can help you track down your error.

Some comments that may or may not be relevant to your error:

You do define model in your __init__() function, but, in the code as
posted, it is a local variable to __init__(). You do not define self.model.

Your __call__() function uses self.model, but it is not defined in your
code, as posted.

Assigning into a tensor by indexing the tensor is an inplace modification
and is sometimes the cause of inplace-modification errors. But it doesn’t
appear that you are using the result of post_process() in your forward
pass.

The following post discusses how to debug inplace-modification errors
in some detail:

Good luck!

K. Frank