A question on @register_grad_sampler for a custom Conv2D-based module

Dear all,

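I have a custom module named MyMod. It is essentially a Conv2D, but its convolution weights are built from trainable coefficients and a constant dictionary: each 3x3 filter is a weighted sum of the dictionary's 9 fixed 3x3 basis filters, with the trainable tensor supplying the weights.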

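import torch
import torch.nn as nn
import torch.nn.functional as F

class MyMod(nn.Module):
    def __init__(self, ...):  # constructor arguments elided
        super().__init__()

        # Constant dictionary of 9 basis 3x3 filters, frozen; size torch.Size([1, 9, 3, 3])
        self.dct = nn.Parameter(..., requires_grad=False)  # "a const dict" (elided)
        # Trainable coefficients; size torch.Size([32, 32, 9, 1, 1]).
        # Wrapping in nn.Parameter registers the tensor so the optimizer (and Opacus) can see it.
        self.weight = nn.Parameter(torch.Tensor(...))

    def forward(self, x):
        # [32, 32, 9, 1, 1] * [1, 9, 3, 3] broadcasts to [32, 32, 9, 3, 3];
        # summing over dim 2 gives filt of size torch.Size([32, 32, 3, 3])
        filt = torch.sum(self.weight * self.dct, dim=2)
        x = F.conv2d(x, filt, stride=1, padding=1, dilation=1, groups=1)
        return x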
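For anyone who wants to run this, here is a self-contained version of MyMod. The constructor arguments (`dct`, `in_channels`, `out_channels`) are hypothetical stand-ins for the elided ones, the random initialization is an assumption, and the dictionary is stored as a buffer instead of a frozen nn.Parameter (either way it is never trained).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyMod(nn.Module):
    """Conv2d whose 3x3 filters are linear combinations of fixed basis filters (sketch)."""

    def __init__(self, dct, in_channels=32, out_channels=32):  # hypothetical arguments
        super().__init__()
        num_atoms = dct.shape[1]  # 9 basis filters in the dictionary
        # Frozen dictionary, shape [1, 9, 3, 3]; a buffer follows .to() but is not a parameter.
        self.register_buffer("dct", dct)
        # Trainable coefficients, shape [out, in, 9, 1, 1] (assumed random init).
        self.weight = nn.Parameter(0.1 * torch.randn(out_channels, in_channels, num_atoms, 1, 1))

    def forward(self, x):
        # [out, in, 9, 1, 1] * [1, 9, 3, 3] -> [out, in, 9, 3, 3]; sum over atoms -> [out, in, 3, 3]
        filt = torch.sum(self.weight * self.dct, dim=2)
        return F.conv2d(x, filt, stride=1, padding=1, dilation=1, groups=1)


torch.manual_seed(0)
mod = MyMod(torch.randn(1, 9, 3, 3))
out = mod(torch.randn(4, 32, 16, 16))
print(out.shape)  # torch.Size([4, 32, 16, 16])
```

Only `weight` shows up in `mod.parameters()`, so it is the single tensor a custom grad sampler has to handle.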

I want to write a register_grad_sampler function for it, but I don't know where to start. Below is the grad sampler Opacus registers for Conv2d, which I copied as a starting point. Could you give me some hints on how to modify it to support MyMod? Thanks!

import numpy as np
import torch
import torch.nn as nn
from opacus.grad_sample import register_grad_sampler
from opacus.utils.tensor_utils import unfold2d

@register_grad_sampler(nn.Conv2d)
def compute_conv_grad_sample(module, activations, backprops):
    n = activations.shape[0]
    # im2col: [n, in_channels * kernel_sz, L], L = number of output positions
    activations = unfold2d(
        activations,
        kernel_size=module.kernel_size,  # 3
        padding=module.padding,          # 1
        stride=module.stride,            # 1
        dilation=module.dilation,        # 1
    )
    backprops = backprops.reshape(n, -1, activations.shape[-1])
    # n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
    grad_sample = torch.einsum("noq,npq->nop", backprops, activations)
    # rearrange the above tensor and extract diagonals.
    grad_sample = grad_sample.view(
        n,
        module.groups,
        -1,
        module.groups,
        int(module.in_channels / module.groups),
        np.prod(module.kernel_size),
    )
    grad_sample = torch.einsum("ngrg...->ngr...", grad_sample).contiguous()
    shape = [n] + list(module.weight.shape)
    ret = {module.weight: grad_sample.view(shape)}
    if module.bias is not None:
        ret[module.bias] = torch.sum(backprops, dim=2)
    return ret
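For reference, in the groups=1 case the core of this sampler is an im2col (unfold) followed by a per-sample contraction over output positions. The sketch below reproduces that core with plain `torch.nn.functional.unfold` in place of Opacus's `unfold2d` (an assumption made so it runs without Opacus) and checks it against autograd run on each sample separately. The helper name is hypothetical.

```python
import torch
import torch.nn.functional as F


def conv2d_per_sample_weight_grad(x, backprops, kernel_size=3, padding=1, stride=1, dilation=1):
    """Per-sample weight gradient of a groups=1, square-kernel Conv2d (hypothetical helper).

    x:         layer input, [n, c_in, H, W]
    backprops: dLoss/dOutput, [n, c_out, H_out, W_out]
    returns:   [n, c_out, c_in, kernel_size, kernel_size]
    """
    n, c_in = x.shape[:2]
    # im2col: [n, c_in * kernel_size**2, L], L = number of output positions
    cols = F.unfold(x, kernel_size, dilation=dilation, padding=padding, stride=stride)
    bp = backprops.reshape(n, -1, cols.shape[-1])  # [n, c_out, L]
    # Sum over output positions q separately for every sample n
    grad = torch.einsum("noq,npq->nop", bp, cols)  # [n, c_out, c_in * kernel_size**2]
    return grad.reshape(n, -1, c_in, kernel_size, kernel_size)


torch.manual_seed(0)
x = torch.randn(3, 4, 8, 8)
w = torch.randn(5, 4, 3, 3, requires_grad=True)
# Pretend dLoss/dOutput is g, i.e. each sample's loss is sum(output * g)
g = torch.randn(3, 5, 8, 8)
per_sample = conv2d_per_sample_weight_grad(x, g)
# Reference: one backward pass per sample
ref = torch.stack([
    torch.autograd.grad((F.conv2d(x[i:i + 1], w, padding=1) * g[i:i + 1]).sum(), w)[0]
    for i in range(3)
])
print(per_sample.shape, torch.allclose(per_sample, ref, atol=1e-4))
# torch.Size([3, 5, 4, 3, 3]) True
```

The extra view/einsum with `"ngrg...->ngr..."` in the Opacus version only matters for groups > 1, where it keeps the block-diagonal part of this result.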

Hey Lei Jiang,

Thanks for your interest! The simplest approach would be the following:

  • Wrap the filt computation in an nn.Module and write a custom grad sampler for that module.
  • Then feed that module's output (filt) into the standard Conv2d operation. That way you only need to write a single grad sampler, for your custom part (filt).
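The custom part reduces to the chain rule, since filt is linear in the trainable coefficients. Writing $W$ for `MyMod.weight` with its trailing 1x1 dimensions dropped, $D$ for `dct[0]`, $F$ for `filt`, and $\ell_n$ for the loss of sample $n$ (notation introduced here for illustration), the per-sample gradient of $W$ is the per-sample filter gradient, which is what the Conv2d sampler computes from MyMod's input and output gradient, contracted with the dictionary over the spatial kernel indices:

```latex
F_{o,i,h,w} = \sum_{k=1}^{9} W_{o,i,k}\, D_{k,h,w}
\qquad\Longrightarrow\qquad
\frac{\partial \ell_n}{\partial W_{o,i,k}}
  = \sum_{h=1}^{3}\sum_{w=1}^{3} \frac{\partial \ell_n}{\partial F_{o,i,h,w}}\, D_{k,h,w}
```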

This should make things easier. Don't hesitate to grad-check your computations on small examples, e.g. by comparing against per-sample gradients obtained from separate backward passes.
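A grad check along those lines might look like the sketch below: it computes the per-sample filter gradient the same way the Conv2d sampler does (using `torch.nn.functional.unfold` instead of Opacus's `unfold2d`), projects it onto the dictionary via the chain rule, and compares the result with one backward pass per sample. The function name and arguments are hypothetical; stride 1, groups=1 and no bias are assumed, matching MyMod's forward.

```python
import torch
import torch.nn.functional as F


def mymod_per_sample_weight_grad(x, backprops, dct, padding=1):
    """Per-sample gradient of MyMod.weight via the chain rule (sketch).

    x:         MyMod input, [n, c_in, H, W]
    backprops: dLoss/dOutput of MyMod, [n, c_out, H, W]
    dct:       frozen dictionary, [1, K, k, k]
    returns:   [n, c_out, c_in, K, 1, 1], i.e. [n] + list(MyMod.weight.shape)
    """
    n, c_in = x.shape[:2]
    num_atoms, k = dct.shape[1], dct.shape[-1]
    # Step 1: per-sample gradient w.r.t. filt, as in the Conv2d sampler
    cols = F.unfold(x, k, padding=padding)             # [n, c_in*k*k, L]
    bp = backprops.reshape(n, -1, cols.shape[-1])      # [n, c_out, L]
    g_filt = torch.einsum("noq,npq->nop", bp, cols)    # [n, c_out, c_in*k*k]
    g_filt = g_filt.reshape(n, -1, c_in, k, k)         # [n, c_out, c_in, k, k]
    # Step 2: filt[o,i,h,w] = sum_a weight[o,i,a] * dct[a,h,w], so contract with each atom
    g_w = torch.einsum("noihw,ahw->noia", g_filt, dct[0])  # [n, c_out, c_in, K]
    return g_w.reshape(n, g_w.shape[1], c_in, num_atoms, 1, 1)


torch.manual_seed(0)
n, c_in, c_out, num_atoms = 3, 4, 5, 9
dct = torch.randn(1, num_atoms, 3, 3)
weight = torch.randn(c_out, c_in, num_atoms, 1, 1, requires_grad=True)
x = torch.randn(n, c_in, 8, 8)


def forward(inp):
    filt = torch.sum(weight * dct, dim=2)  # same construction as MyMod.forward
    return F.conv2d(inp, filt, padding=1)


g = torch.randn(n, c_out, 8, 8)  # stands in for dLoss/dOutput
fast = mymod_per_sample_weight_grad(x, g, dct)
# Reference: one backward pass per sample
ref = torch.stack([
    torch.autograd.grad((forward(x[i:i + 1]) * g[i:i + 1]).sum(), weight)[0]
    for i in range(n)
])
print(fast.shape, torch.allclose(fast, ref, atol=1e-4))
# torch.Size([3, 5, 4, 9, 1, 1]) True
```

If the check passes, the function could be wired into Opacus by registering a sampler for MyMod with `register_grad_sampler(MyMod)` that returns `{module.weight: mymod_per_sample_weight_grad(activations, backprops, module.dct)}`. This is untested against Opacus itself, and the sampler signature has changed between releases (newer versions may pass `activations` as a list of tensors), so check it against the installed version.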

Hope this helps,