Manually edit intermediate activations

Hi, I’m trying to implement activation quantization: for example, round the outputs of every intermediate layer to their nearest integers during forward propagation and use the quantized activations to compute gradients during the backward pass.

However, if we do something like:

import torch
from torch.autograd import Variable

a = Variable(torch.FloatTensor([1]), requires_grad=True)
b = 1.5 * a
c = torch.round(b)
d = c * 2
d.backward()

The result of a.grad will be 0 instead of 1.5 * 2 = 3. Obviously, the rounding op destroys the gradient chain, since its derivative is zero almost everywhere, so it behaves as if values had been assigned directly to the variable.

I tried to play with the grad_fn attributes and manually build up my own chain, but it seems we are prevented from manipulating them directly.

So far, among all the approaches I have tried, the only one that works is to cast to an integer type so that the values get “truncated”:

a = Variable(torch.FloatTensor([1]), requires_grad=True)
b = 1.5 * a
c = (b + 0.5).int().float()
d = c * 2
d.backward()

This time a.grad becomes 3, which is what I want, but it’s not elegant at all.

Is there a cleaner way to do this?

The clean way to do this is to write your own torch.autograd.Function, where you write the forward and backward passes that you want.
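
For reference, here is a minimal sketch of such a Function (the class name RoundSTE and the choice of passing the gradient through unchanged in backward are illustrative assumptions on my part, not something prescribed above); it uses the current torch.autograd.Function API with static methods:

import torch
from torch.autograd import Function

class RoundSTE(Function):
    @staticmethod
    def forward(ctx, input):
        # Quantize the activation to the nearest integer in the forward pass.
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat round() as the identity
        # and pass the incoming gradient through unchanged.
        return grad_output

# Usage on the example from the question:
a = torch.tensor([1.0], requires_grad=True)
b = 1.5 * a
c = RoundSTE.apply(b)
d = c * 2
d.backward()
print(a.grad)  # tensor([3.])

Treating the rounding step as the identity in the backward pass is the usual straight-through estimator convention; any other surrogate gradient could be returned from backward instead.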

Got it! Thanks Kaixhin.