This paper introduces a new activation function with a trainable parameter, used for quantizing activations (Eqs. 1, 2, and 3).

The quantization is as follows:

`y = 0.5(|x| - |x - alpha| + alpha)`

`y_q = round(y * (2^k - 1) / alpha) * alpha / (2^k - 1)`,

where `alpha` is the trainable parameter and `k` is the number of bits.
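
For concreteness, this is how I currently read the forward pass in PyTorch (an untested sketch; the function name `quantize_activation` is mine, not from the paper):

```python
import torch

def quantize_activation(x, alpha, k):
    # Clip to [0, alpha]: y = 0.5 * (|x| - |x - alpha| + alpha)
    y = 0.5 * (x.abs() - (x - alpha).abs() + alpha)
    # Linear k-bit quantization over [0, alpha]
    scale = (2 ** k - 1) / alpha
    return torch.round(y * scale) / scale
```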

The partial derivative of `y_q` with respect to `alpha` is given in Eq. 3 of the paper.
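
If I read Eq. 3 correctly, applying the straight-through estimator (treating `round(.)` as the identity, so `y_q ≈ y`) reduces it to

`d(y_q)/d(alpha) = 0 if x < alpha, else 1`,

i.e. only activations that get clipped at `alpha` contribute to the gradient of `alpha`.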

What is the easiest way to integrate this activation function into a PyTorch model?

I was thinking of defining an `nn.Module` that holds `alpha` as a `Parameter`. The problem is that there are two sets of gradients here: one for updating `alpha` and another for the gradients with respect to the inputs. I assume the latter should be handled in a custom `backward()`, but I'm not sure how to update `alpha`.
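
For reference, here is a rough sketch of the pattern I had in mind (untested; the class names and the straight-through treatment of `round()` are my own reading of the paper, not something it prescribes). A custom `torch.autograd.Function` returns one gradient per `forward()` argument, so the gradient for `alpha` is simply another return value, and the wrapping `nn.Module` owns `alpha` as a `Parameter`:

```python
import torch
import torch.nn as nn

class ActQuantFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha, k):
        ctx.save_for_backward(x, alpha)
        # Same forward computation as the sketch above
        y = 0.5 * (x.abs() - (x - alpha).abs() + alpha)
        scale = (2 ** k - 1) / alpha
        return torch.round(y * scale) / scale

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        # Straight-through estimator: treat round() as the identity and
        # differentiate the clipping function y instead of y_q.
        # dy/dx = 1 inside the clipping range [0, alpha), else 0
        grad_x = grad_output * ((x >= 0) & (x < alpha)).to(grad_output.dtype)
        # d(y_q)/d(alpha) = 1 where x >= alpha, else 0 (my reading of Eq. 3);
        # summed because the scalar alpha broadcasts over the whole tensor
        grad_alpha = (grad_output * (x >= alpha).to(grad_output.dtype)).sum()
        # One gradient per forward() argument; k is not differentiable
        return grad_x, grad_alpha, None

class ActQuant(nn.Module):
    def __init__(self, k=4, alpha_init=10.0):
        super().__init__()
        self.k = k
        self.alpha = nn.Parameter(torch.tensor(alpha_init))

    def forward(self, x):
        return ActQuantFn.apply(x, self.alpha, self.k)
```

If this is right, `loss.backward()` populates `alpha.grad` alongside the input gradients, and any optimizer built over `model.parameters()` updates `alpha` like any other parameter, so nothing special is needed in `backward()` for the update itself. Is this the right approach, or is there a simpler way?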