Overwrite backward propagation for nn.Parameter for weight quantization

I have written a torch.autograd.Function for quantization as follows:

import torch

class Quantize(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # quantize to the nearest integer in the forward pass
        return input.round()

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: pass the gradient through unchanged
        return grad_output

fake_uniform_quantize = Quantize.apply
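For context, a quick sanity check (the tensor values here are just for illustration) that the rounded forward and the straight-through backward behave as intended on an ordinary tensor:

x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
y = fake_uniform_quantize(x)
y.sum().backward()
print(y)       # tensor([ 0.,  2., -1.], grad_fn=...)
print(x.grad)  # tensor([1., 1., 1.])  -- gradient passed straight through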

How can I apply it to a weight parameter? E.g.

self.kernel = nn.Parameter(torch.Tensor(out_channel, in_channel, kernel_size, kernel_size))
self.kernel = fake_uniform_quantize(self.kernel)

But I end up with the error:

TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'kernel' (torch.nn.Parameter or None expected)

I solved the problem like this:

self.kernel = nn.Parameter(torch.Tensor(out_channel, in_channel, kernel_size, kernel_size))
# quantize into a local tensor instead of reassigning the nn.Parameter
kernel = fake_uniform_quantize(self.kernel)

then, in the forward pass ...

x = F.conv2d(x, kernel, None, stride=self.stride, padding=0, dilation=self.dilation, groups=self.group)
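Putting it together, here is a minimal sketch of how this could look inside a module. The class name QuantConv2d and the constructor signature are just assumptions matching the snippets above, not a definitive implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class QuantConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size,
                 stride=1, dilation=1, group=1):
        super().__init__()
        # the stored parameter stays full precision
        self.kernel = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        self.stride = stride
        self.dilation = dilation
        self.group = group

    def forward(self, x):
        # quantize a copy on every forward pass; the straight-through
        # backward lets the gradient reach the full-precision self.kernel
        kernel = fake_uniform_quantize(self.kernel)
        return F.conv2d(x, kernel, None, stride=self.stride,
                        padding=0, dilation=self.dilation, groups=self.group)

This keeps the nn.Parameter untouched (so the optimizer updates the full-precision weights), while the convolution always sees the quantized copy.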
