Hi everyone,
I am attempting to compute the analytical derivative of SSIM with respect to the first image, which I've named dL_ssim_dimg1 in the code below. However, I've noticed a discrepancy between the analytical gradient and the gradient computed by autograd: the maximum element-wise error is approximately $1 \times 10^{-4}$ and stays fairly consistent across iterations. I've reviewed my computation multiple times, and I believe it's correct. A detailed derivation can be found here.
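In short, the decomposition I derived is (standard SSIM notation: $\mu_i$, $\sigma_i^2$, $\sigma_{12}$ are the window-filtered means, variances, and covariance):

$$
\mathrm{SSIM} = l \cdot cs, \qquad
l = \frac{2\mu_1\mu_2 + C_1}{\mu_1^2 + \mu_2^2 + C_1}, \qquad
cs = \frac{2\sigma_{12} + C_2}{\sigma_1^2 + \sigma_2^2 + C_2},
$$

$$
\frac{\partial\,\mathrm{SSIM}}{\partial\,\mathrm{img1}}
= \frac{\partial l}{\partial\,\mathrm{img1}}\,cs + l\,\frac{\partial cs}{\partial\,\mathrm{img1}},
$$

with $\partial \mu_1 / \partial\,\mathrm{img1}$ obtained by convolving with the flipped window; lp_x, cs_x, and dL_ssim_dimg1 in the code below compute exactly these terms.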
When I compare against other autograd results the same way, everything is consistent: the gradient of mu1.mean() is nearly identical, and an L1 loss also shows only negligible error. So I'm not sure the discrepancy can be attributed solely to floating-point precision; to me, the error seems too large to be caused by that alone.
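If it helps judge the precision question, the sanity check I have in mind is rerunning the whole comparison in float64: if the gap were pure float32 rounding, it should shrink by several orders of magnitude. A minimal sketch, with placeholder shapes and a uniform window standing in for my actual Gaussian window:

// Precision check (sketch): run everything in double.
auto opts = torch::TensorOptions().dtype(torch::kDouble);
auto img1 = torch::rand({1, 3, 64, 64}, opts).requires_grad_(true);
auto img2 = torch::rand({1, 3, 64, 64}, opts);
// Uniform stand-in for the Gaussian window; shape is [channel, 1, k, k] for groups=channel.
auto window = torch::ones({3, 1, 11, 11}, opts) / (11.0 * 11.0);
auto [loss, analytic] = ssim(img1, img2, window, 11, 3);  // ssim() calls backward() internally
// If the mismatch were float32 rounding, this should now be far below 1e-4.
std::cout << (img1.grad() - analytic).abs().max().item<double>() << std::endl;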
Is there a way to precisely trace how the loss gradient with respect to img1 is computed (the kind of per-intermediate check I mean is sketched after the code)? Incidentally, img2 is the reference image used for optimization; is there a way to track its computation process as well? I've already tried various approaches, such as rearranging the computations, but the error stays roughly the same.
Code:
std::pair<torch::Tensor, torch::Tensor> ssim(const torch::Tensor& img1, const torch::Tensor& img2,
                                             const torch::Tensor& window, int window_size, int channel) {
    static const float C1 = 0.01f * 0.01f;
    static const float C2 = 0.03f * 0.03f;

    // Grouped "same" convolution with the Gaussian window.
    const auto conv = [&](const torch::Tensor& x) {
        return torch::nn::functional::conv2d(
            x, window, torch::nn::functional::Conv2dFuncOptions().padding(window_size / 2).groups(channel));
    };

    // Windowed means, variances, and covariance.
    const auto mu1 = conv(img1);
    const auto mu2 = conv(img2);
    const auto mu1_sq = mu1.pow(2);
    const auto mu2_sq = mu2.pow(2);
    const auto mu1_mu2 = mu1 * mu2;
    const auto sigma1_sq = conv(img1 * img1) - mu1_sq;
    const auto sigma2_sq = conv(img2 * img2) - mu2_sq;
    const auto sigma12 = conv(img1 * img2) - mu1_mu2;

    // Luminance and contrast-structure terms; the SSIM map is their product.
    const auto l_p = (2.f * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1);
    const auto cs_p = (2.f * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2);
    const auto ssim_map = l_p * cs_p;

    // Analytical gradient: d(mu1)/d(img1) via convolution with the flipped window,
    // then the product rule on l_p * cs_p (see the decomposition above).
    const auto grad_mu1_wrt_img1 = torch::nn::functional::conv2d(
        torch::ones_like(mu1), torch::flip(window, {2, 3}),
        torch::nn::functional::Conv2dFuncOptions().padding(window_size / 2).groups(channel));
    const auto lp_x = 2.f * grad_mu1_wrt_img1 * ((mu2 - mu1 * l_p) / (mu1_sq + mu2_sq + C1));
    const auto cs_x = (2.f / (sigma1_sq + sigma2_sq + C2)) * grad_mu1_wrt_img1
                      * ((img2 - mu2) - cs_p * (img1 - mu1));
    auto dL_ssim_dimg1 = (lp_x * cs_p + l_p * cs_x) / static_cast<float>(img1.size(1) * img1.size(2));

    // Autograd reference: backward() on the scalar mean populates img1.grad().
    auto loss = ssim_map.mean();
    loss.backward();
    return {loss, dL_ssim_dimg1};
}
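And to be concrete about the tracing question: what I'd like is per-intermediate access to what autograd computes. The closest I can think of is a hypothetical variant of the function above, say ssim_map_only(), that builds ssim_map exactly the same way but returns it (plus intermediates like mu1) without calling backward(), so the graph is still alive and torch::autograd::grad can query it:

// Hypothetical helper: same body as ssim() above, but returns {ssim_map, mu1}
// and does not call backward(), leaving the graph available for queries.
auto [ssim_map, mu1] = ssim_map_only(img1, img2, window, 11, 3);
auto loss = ssim_map.mean();

// Gradient of the loss w.r.t. the intermediate mu1, straight from autograd,
// to sanity-check individual terms of the analytical derivation.
auto d_mu1 = torch::autograd::grad({loss}, {mu1}, /*grad_outputs=*/{},
                                   /*retain_graph=*/true)[0];

// End-to-end gradient w.r.t. img1 from the same graph.
auto d_img1 = torch::autograd::grad({loss}, {img1})[0];

If something like this (or a proper hook-based way to inspect each node's gradient) is the standard approach, pointers would be appreciated.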