Dear all,
I have a custom module named MyMod. It is essentially a Conv2d, but its convolution weights are the product of a trainable weight tensor and a constant dictionary (a fixed basis).
class MyMod(nn.Module):
    """Conv2d whose effective kernel is built from a fixed basis ("dictionary").

    The effective 3x3 filter is ``filt = sum_k weight[..., k, 0, 0] * dct[0, k]``,
    i.e. each output/input channel pair mixes the ``basis_size`` fixed basis
    atoms with trainable coefficients, then a plain conv2d is applied.

    Args:
        in_channels: number of input channels (default 32, as in the original).
        out_channels: number of output channels (default 32).
        basis_size: number of basis atoms in the dictionary (default 9).
        kernel_size: spatial size of each basis atom (default 3).
    """

    def __init__(self, in_channels=32, out_channels=32, basis_size=9, kernel_size=3):
        super().__init__()
        # The constant dictionary must NOT be a Parameter: register it as a
        # buffer so it follows .to(device)/.cuda() and state_dict, but is
        # invisible to optimizers and to per-sample gradient machinery.
        # Shape: (1, basis_size, k, k) == torch.Size([1, 9, 3, 3]).
        self.register_buffer("dct", torch.randn(1, basis_size, kernel_size, kernel_size))
        # BUG in original: a plain torch.Tensor is never registered with the
        # module, so .parameters() / optimizers / Opacus never see it.
        # It must be an nn.Parameter to be trainable.
        # Shape: (out, in, basis_size, 1, 1) == torch.Size([32, 32, 9, 1, 1]).
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, basis_size, 1, 1)
        )

    def forward(self, x):
        # Broadcast: (out, in, B, 1, 1) * (1, B, k, k) -> (out, in, B, k, k);
        # summing over the basis dim yields the effective (out, in, k, k) filter.
        filt = torch.sum(self.weight * self.dct, dim=2)
        return F.conv2d(x, filt, stride=1, padding=1, dilation=1, groups=1)
I want to write a per-sample gradient function for it, registered via register_grad_sampler, but I have no clue how. I just copied the grad sampler for Conv2d below. Could you please give me some hints on modifying it to support MyMod? Thanks!
@register_grad_sampler([MyMod])
def compute_mymod_grad_sample(module, activations, backprops):
    """Per-sample gradients for MyMod's trainable ``weight``.

    The copied Conv2d sampler is wrong for MyMod on two counts:
    1. It reads ``module.kernel_size/padding/stride/dilation/groups/
       in_channels/bias`` — none of these exist on MyMod (AttributeError).
       MyMod always convolves with kernel_size=3, stride=1, padding=1,
       dilation=1, groups=1 and has no bias, so those are hard-coded here.
    2. It returns the per-sample gradient of the *effective* filter
       ``filt = sum(weight * dct, dim=2)``, not of ``module.weight``.
       We must chain-rule through that reduction:
       ``grad_w[n,o,i,k,0,0] = sum_{h,w} grad_filt[n,o,i,h,w] * dct[0,k,h,w]``.

    Args:
        module: the MyMod instance.
        activations: layer input, shape (n, in_ch, H, W).
        backprops: gradient of the loss w.r.t. the layer output,
            shape (n, out_ch, H, W).

    Returns:
        dict mapping ``module.weight`` to its per-sample gradient of shape
        (n, out_ch, in_ch, basis_size, 1, 1).
    """
    n = activations.shape[0]
    out_ch, in_ch, basis = module.weight.shape[0], module.weight.shape[1], module.weight.shape[2]
    k = module.dct.shape[-1]  # spatial kernel size of the basis atoms (3)

    # im2col: (n, in_ch*k*k, L) where L = number of output positions.
    activations = unfold2d(
        activations, kernel_size=k, padding=1, stride=1, dilation=1
    )
    backprops = backprops.reshape(n, -1, activations.shape[-1])

    # Per-sample gradient w.r.t. the effective filter, exactly as for Conv2d
    # with groups=1: (n, out_ch, in_ch*k*k).
    grad_filt = torch.einsum("noq,npq->nop", backprops, activations)
    grad_filt = grad_filt.view(n, out_ch, in_ch, k, k)

    # Chain rule through filt = sum_k weight[o,i,k,0,0] * dct[0,k,:,:]:
    # d filt[o,i,h,w] / d weight[o,i,k] = dct[0,k,h,w], so contract the
    # spatial dims of grad_filt against the dictionary (b is the size-1
    # leading dim of dct).
    grad_weight = torch.einsum("noihw,bkhw->noik", grad_filt, module.dct)

    # MyMod has no bias, so weight is the only entry.
    return {module.weight: grad_weight.reshape(n, out_ch, in_ch, basis, 1, 1)}