Sobel loss per channel of a multi-channel image (each channel treated like a gray image)

My model’s input is a Bayer image of shape (b, 4, h, w).
I want to apply a Sobel loss to each channel separately, treating each one as if it were a gray image:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from .data_record import DataRecord

class GradLayer(nn.Module):
    def __init__(self):
        super(GradLayer, self).__init__()
        # Sobel kernels for the vertical and horizontal derivatives
        kernel_v = [[1, 0, -1],
                    [2, 0, -2],
                    [1, 0, -1]]
        kernel_h = [[1, 2, 1],
                    [0, 0, 0],
                    [-1, -2, -1]]
        # Shape (1, 1, 3, 3); the kernels are fixed, not trained
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False)
        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False)

    @staticmethod
    def get_gray(x):
        """Convert an RGB image to its gray version."""
        gray_coeffs = [65.738, 129.057, 25.064]
        convert = x.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
        x_gray = x.mul(convert).sum(dim=1)
        return x_gray.unsqueeze(1)

    def forward(self, x):
        # RGB input is reduced to one gray channel; other inputs keep their channels
        if x.shape[1] == 3:
            x = self.get_gray(x)

        # One group per channel: each channel is filtered independently
        _n_groups = x.shape[1]
        x_v = F.conv2d(x, self.weight_v.repeat((_n_groups, 1, 1, 1)), padding=1, groups=_n_groups)
        x_h = F.conv2d(x, self.weight_h.repeat((_n_groups, 1, 1, 1)), padding=1, groups=_n_groups)
        # Gradient magnitude; the small epsilon avoids an infinite sqrt gradient at 0
        x = torch.sqrt(torch.pow(x_v, 2) + torch.pow(x_h, 2) + 1e-6)

        return x

class SobelLoss(nn.Module):
    def __init__(self):
        super(SobelLoss, self).__init__()
        self.loss = nn.L1Loss()
        self.grad_layer = GradLayer()
        # (b, C, H, W) -> (b, 4 * C, H / 2, W / 2)
        self.unshuffle = nn.PixelUnshuffle(2)

    def forward(self, loss_input: DataRecord) -> torch.Tensor:
        _target_seq = loss_input.target_sequence
        # For a 5-D target sequence, take its last entry along dim 1 and unshuffle both tensors
        target = self.unshuffle(_target_seq[:, -1]) if len(_target_seq.shape) == 5 else _target_seq
        prediction = self.unshuffle(loss_input.prediction) if len(_target_seq.shape) == 5 else loss_input.prediction

        # L1 distance between the gradient magnitudes of prediction and target
        output_grad = self.grad_layer(prediction)
        gt_grad = self.grad_layer(target)
        return self.loss(output_grad, gt_grad)
```

I’m currently using a grouped convolution, but I’m confused about grouping versus replicating the kernel:

```python
# group
x_v = F.conv2d(x, self.weight_v.repeat((_n_groups, 1, 1, 1)), padding=1, groups=_n_groups)
# replicate
x_v = F.conv2d(x, self.weight_v.repeat((_n_groups, _n_groups, 1, 1)), padding=1)
```

Could you please elaborate on this and give some intuition for the difference?
Both give me the desired output shape (b, 4, h, w), but I wonder whether both really compute the gradients or whether I’m missing something.


It depends on what you are trying to do. As you observed, both give you the expected output shape, but the difference is that the first one treats each channel independently, because _n_groups equals the number of channels. That is, each output channel depends on only a single input channel: x_v[:,0,:,:] depends only on x[:,0,:,:], x_v[:,1,:,:] depends only on x[:,1,:,:], and so on.
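
A small check of this, as a sketch: the random (2, 4, 8, 8) tensor is only a stand-in for the Bayer input, and only the vertical kernel from GradLayer is used. Adding noise to input channel 1 should change output channel 1 and leave the other three untouched. The printed weight shape follows conv2d’s (out_channels, in_channels / groups, kH, kW) layout.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Vertical Sobel kernel from GradLayer, shape (1, 1, 3, 3)
kernel_v = torch.tensor([[1., 0., -1.],
                         [2., 0., -2.],
                         [1., 0., -1.]]).view(1, 1, 3, 3)

x = torch.randn(2, 4, 8, 8)  # random stand-in for a (b, 4, h, w) Bayer tensor
n = x.shape[1]

# groups=n: weight shape is (out_channels, in_channels // groups, kH, kW) = (4, 1, 3, 3)
w_grouped = kernel_v.repeat(n, 1, 1, 1)
x_v1 = F.conv2d(x, w_grouped, padding=1, groups=n)

# Add noise to input channel 1 only, then recompute
x_pert = x.clone()
x_pert[:, 1] += torch.randn(2, 8, 8)
x_v1_pert = F.conv2d(x_pert, w_grouped, padding=1, groups=n)

# Which output channels reacted to the change?
changed = [not torch.allclose(x_v1[:, c], x_v1_pert[:, c]) for c in range(n)]
print(w_grouped.shape)  # torch.Size([4, 1, 3, 3])
print(changed)          # [False, True, False, False]
```

Each output channel here is exactly what you would get by running a plain single-channel convolution on that input channel alone.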

In the second case, the number of groups is 1 (the default), so each output channel adds together the results across all of the input channels. Now x_v[:,0,:,:] is computed from x[:,0,:,:], x[:,1,:,:], x[:,2,:,:] and x[:,3,:,:]. In fact, since your filters are the same for every output channel in this case, you should be able to verify that x_v[:,0,:,:] == x_v[:,1,:,:] == x_v[:,2,:,:] == x_v[:,3,:,:] (barring minor numerical differences).
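
A sketch of that verification, again assuming a random stand-in input and only the vertical kernel. With groups=1 the weight has shape (out_channels, in_channels, kH, kW) = (4, 4, 3, 3), so every output channel applies the same four filters to the same four inputs:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Vertical Sobel kernel from GradLayer, shape (1, 1, 3, 3)
kernel_v = torch.tensor([[1., 0., -1.],
                         [2., 0., -2.],
                         [1., 0., -1.]]).view(1, 1, 3, 3)

x = torch.randn(2, 4, 8, 8)  # random stand-in for a (b, 4, h, w) Bayer tensor
n = x.shape[1]

# groups=1 (default): weight shape is (out_channels, in_channels, kH, kW) = (4, 4, 3, 3)
w_full = kernel_v.repeat(n, n, 1, 1)
x_v2 = F.conv2d(x, w_full, padding=1)

# All output channels should match (up to float rounding)
same = all(torch.allclose(x_v2[:, 0], x_v2[:, c], atol=1e-5) for c in range(1, n))
print(w_full.shape)  # torch.Size([4, 4, 3, 3])
print(same)          # True
```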

If we call x_v computed with the first method (_n_groups = x.shape[1]) x_v1, and x_v computed with the second method x_v2, then x_v2[:,0,:,:] == x_v1[:,0,:,:] + x_v1[:,1,:,:] + x_v1[:,2,:,:] + x_v1[:,3,:,:]. Because convolution is linear, this is also the Sobel response of the channel sum x[:,0,:,:] + x[:,1,:,:] + x[:,2,:,:] + x[:,3,:,:], so the second method measures the gradient of the summed image rather than the gradient of each channel. Note that these equalities are a consequence of the repeated filters; they do not hold in general, where the filters can take arbitrary values.
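
Both the sum identity and the linearity observation can be checked numerically with a sketch like the following (random stand-in input; atol=1e-5 absorbs float32 rounding from summing in a different order):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Vertical Sobel kernel from GradLayer, shape (1, 1, 3, 3)
kernel_v = torch.tensor([[1., 0., -1.],
                         [2., 0., -2.],
                         [1., 0., -1.]]).view(1, 1, 3, 3)

x = torch.randn(2, 4, 8, 8)  # random stand-in for a (b, 4, h, w) Bayer tensor
n = x.shape[1]

x_v1 = F.conv2d(x, kernel_v.repeat(n, 1, 1, 1), padding=1, groups=n)  # grouped
x_v2 = F.conv2d(x, kernel_v.repeat(n, n, 1, 1), padding=1)            # replicated

# x_v2[:, 0] equals the sum of the per-channel responses
sum_matches = torch.allclose(x_v2[:, 0], x_v1.sum(dim=1), atol=1e-5)

# By linearity (zero padding included), it is also the Sobel response of the channel sum
x_sum = x.sum(dim=1, keepdim=True)
linear_matches = torch.allclose(x_v2[:, :1], F.conv2d(x_sum, kernel_v, padding=1), atol=1e-5)

print(sum_matches, linear_matches)  # True True
```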

Since it seems strange to want the same values in every output channel, I assume you want the former, but I don’t know the context of your use case.
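
If per-channel behaviour is what you want, here is a minimal sketch of it. The helpers sobel_magnitude and sobel_magnitude_folded are hypothetical names, not PyTorch APIs. The second one folds the channels into the batch dimension so each channel is literally processed as its own gray image, which should give the same result as the grouped convolution:

```python
import torch
import torch.nn.functional as F


def sobel_magnitude(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Per-channel Sobel gradient magnitude of a (b, c, h, w) tensor (hypothetical helper)."""
    k_v = x.new_tensor([[1., 0., -1.],
                        [2., 0., -2.],
                        [1., 0., -1.]])
    k_h = x.new_tensor([[1., 2., 1.],
                        [0., 0., 0.],
                        [-1., -2., -1.]])
    c = x.shape[1]
    # One (1, 3, 3) filter per channel with groups=c: output i only sees input channel i
    w_v = k_v.view(1, 1, 3, 3).repeat(c, 1, 1, 1)
    w_h = k_h.view(1, 1, 3, 3).repeat(c, 1, 1, 1)
    g_v = F.conv2d(x, w_v, padding=1, groups=c)
    g_h = F.conv2d(x, w_h, padding=1, groups=c)
    return torch.sqrt(g_v ** 2 + g_h ** 2 + eps)


def sobel_magnitude_folded(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Same result by folding channels into the batch (hypothetical helper)."""
    b, c, h, w = x.shape
    gray = x.reshape(b * c, 1, h, w)  # each channel becomes a 1-channel gray image
    return sobel_magnitude(gray, eps).reshape(b, c, h, w)


torch.manual_seed(0)
pred = torch.randn(2, 4, 16, 16)    # random stand-ins for (b, 4, h, w) prediction / target
target = torch.randn(2, 4, 16, 16)

same_result = torch.allclose(sobel_magnitude(pred), sobel_magnitude_folded(pred), atol=1e-5)
print(same_result)  # True

# Mirrors SobelLoss's L1 term, without the DataRecord / unshuffle handling
loss = F.l1_loss(sobel_magnitude(pred), sobel_magnitude(target))
print(loss)
```

The grouped form avoids the reshape, while the folded form makes the "each channel is a gray image" reading explicit. GradLayer.forward in the question already uses the grouped form, so for 4-channel input it computes per-channel gradients.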
