I am trying to run a custom module which performs quantization:

```
import torch
import torch.nn as nn
from torch.autograd import Variable


class Quantizer(nn.Module):
    """
    Scalar Quantizer module
    Source: https://github.com/mitscha/dplc
    """
    def __init__(self, centers=[-1.0, 1.0], sigma=1.0):
        super(Quantizer, self).__init__()
        self.centers = centers
        self.sigma = sigma

    def forward(self, x):
        centers = x.data.new(self.centers)
        xsize = list(x.size())

        # Compute differentiable soft quantized version
        x = x.view(*(xsize + [1]))
        level_var = Variable(centers, requires_grad=False)
        dist = torch.pow(x - level_var, 2)
        output = torch.sum(level_var * nn.functional.softmax(-self.sigma * dist, dim=-1), dim=-1)

        # Compute hard quantization (invisible to autograd)
        _, symbols = torch.min(dist.data, dim=-1, keepdim=True)
        for _ in range(len(xsize)):
            centers.unsqueeze_(0)  # in-place error
        centers = centers.expand(*(xsize + [len(self.centers)]))
        quant = centers.gather(-1, symbols.long()).squeeze_(dim=-1)

        # Replace activations in soft variable with hard quantized version
        output.data = quant
        return output


if __name__ == '__main__':
    quantizer = Quantizer(centers=[-2, -1, 1, 2])
    z = quantizer(torch.tensor([[1.0, 2.5]], requires_grad=True))
    z.sum().backward()  # the in-place error surfaces during this backward pass
```

This code is supposed to compute a hard nearest-neighbor quantization on the forward pass while using a differentiable soft relaxation for the backward pass. The idea is to replace `output.data` with the hard values, so that autograd only ever sees the soft computation (whose forward result is then discarded). It yields the following error:

```
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2]] is at version 3; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
```
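
For context, what I understand the soft/hard trick to be aiming at is the usual straight-through behaviour: hard values on the forward pass, gradients from a smooth surrogate on the backward pass. A minimal standalone example of that behaviour (my own sketch, not the dplc code) would be:

```
import torch

x = torch.tensor([0.3, -0.9], requires_grad=True)
soft = torch.tanh(x)                   # differentiable surrogate
hard = torch.sign(x)                   # hard, non-differentiable values
out = soft + (hard - soft).detach()    # forward: hard values; backward: gradient of soft
out.sum().backward()
print(out)                             # tensor([ 1., -1.], grad_fn=...)
print(x.grad)                          # gradient of tanh, not of sign
```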

Now, I have seen suggestions to fix this, such as changing the in-place operation `centers.unsqueeze_(0)` to the out-of-place `centers.unsqueeze(0)`, but this is suspicious to me, since the whole hard-quantization block is supposed to be invisible to autograd anyway. Does replacing a Variable's `.data` field actually work the way this code assumes, or should I be routing the soft relaxation into the backward pass some other way?
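
To make the question concrete, the alternative I would try (my own untested sketch, class name `QuantizerSTE` is mine, and I'm not sure it preserves whatever the original 0.4.1 `.data` trick relied on) drops the `.data` assignment and uses `detach()` instead:

```
import torch
import torch.nn as nn
import torch.nn.functional as F


class QuantizerSTE(nn.Module):
    """Same soft/hard quantization idea, but without writing to .data."""
    def __init__(self, centers=(-1.0, 1.0), sigma=1.0):
        super().__init__()
        self.register_buffer('centers', torch.tensor(centers, dtype=torch.float))
        self.sigma = sigma

    def forward(self, x):
        dist = (x.unsqueeze(-1) - self.centers) ** 2                    # (..., K)
        soft = (self.centers * F.softmax(-self.sigma * dist, dim=-1)).sum(-1)
        hard = self.centers[dist.argmin(dim=-1)]                        # nearest center, no grad path
        return soft + (hard - soft).detach()                            # hard forward, soft backward
```

This keeps `centers` out of any in-place reshaping entirely, so nothing that the soft path saved for backward gets mutated. Is something like this the intended way to do it, or is there a reason the original code went through `output.data`?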

I am using `torch==1.2.0`, but the module was adapted from code written for 0.4.1.