Getting follwoing error:
Traceback (most recent call last):
File “gaussian_blur.py”, line 36, in
sigma = estimate_sigma(original_image, blurred_image, args.init_sigma)
File “gaussian_blur.py”, line 18, in estimate_sigma
loss = F.mse_loss(blurred, blurred_image)
File “/nfs/tools/humans/conda/envs/datacap/lib/python3.7/site-packages/torch/_tensor.py”, line 396, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File “/nfs/tools/humans/conda/envs/datacap/lib/python3.7/site-packages/torch/autograd/init.py”, line 175, in backward
allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
import torch
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
from PIL import Image
import argparse
import numpy as np
from ipdb import set_trace as bbdef estimate_sigma(original_image, blurred_image, init_sigma=None):
if init_sigma is None:
init_sigma = 1.0
bb()
sigma = torch.tensor(init_sigma, requires_grad=True)
optimizer = torch.optim.Adam([sigma], lr=0.1)for _ in range(100): blurred = GaussianBlur(kernel_size=(int(6 * sigma + 1), int(6 * sigma + 1)))(original_image) loss = F.mse_loss(blurred, blurred_image) optimizer.zero_grad() loss.backward() optimizer.step() return sigma.detach()
if name == “main”:
parser = argparse.ArgumentParser(description=“Estimate Gaussian blur sigma”)
parser.add_argument(“original_image”, type=str, help=“Path to the original image”)
parser.add_argument(“blurred_image”, type=str, help=“Path to the blurred image”)
parser.add_argument(“–init_sigma”, type=float, default=None, help=“Initial guess for sigma”)
args = parser.parse_args()original_image = torch.tensor(np.array(Image.open(args.original_image)), dtype=torch.float32).permute(2, 0, 1)[None, ...] blurred_image = torch.tensor(np.array(Image.open(args.blurred_image)), dtype=torch.float32).permute(2, 0, 1)[None, ...] sigma = estimate_sigma(original_image, blurred_image, args.init_sigma) print(f"Estimated sigma: {sigma.item()}")