Hi, I have the following class, which extends nn.Module, for performing quantization:
```python
import torch
import torch.nn as nn


class LTQ(nn.Module):
    def __init__(self, num_bits):
        super(LTQ, self).__init__()
        init_range = 2.0
        self.n_val = 2**num_bits - 1
        self.interval = init_range / self.n_val
        self.start = nn.Parameter(torch.Tensor([0.0]), requires_grad=True)
        # One learnable step size per quantization interval.
        self.a = nn.Parameter(
            torch.Tensor([self.interval] * self.n_val), requires_grad=True
        )
        self.scale1 = nn.Parameter(torch.Tensor([1.0]), requires_grad=True)
        self.scale2 = nn.Parameter(torch.Tensor([1.0]), requires_grad=True)
        # Constants stored as non-trainable parameters.
        self.two = nn.Parameter(torch.Tensor([2.0]), requires_grad=False)
        self.zero = nn.Parameter(torch.Tensor([0.0]), requires_grad=False)
        self.eps = nn.Parameter(torch.Tensor([1e-3]), requires_grad=False)

    def forward(self, x):
        x = x * self.scale1
        x_forward = x
        x_backward = x
        step_right = self.zero + 0.0
        # Clamp every step size to be at least eps.
        a_pos = torch.where(self.a > self.eps, self.a, self.eps)
        for i in range(self.n_val):
            step_right += self.interval
            if i == 0:
                <some code>
            else:
                <some code>
```
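For reference, here is a minimal sketch of how a module like `LTQ` is recorded by `torch.jit.trace`. It assumes a hypothetical simplified stand-in, `TinyLTQ`, because the real loop body is elided above, and wraps it in an `nn.Sequential` next to an `nn.Conv2d` so the two can be compared. Before inlining, each submodule call should appear as a single `prim::CallMethod` node; `_jit_pass_inline` then replaces each call with the ops recorded inside that submodule's `forward`.

```python
import torch
import torch.nn as nn


class TinyLTQ(nn.Module):
    """Hypothetical stand-in for LTQ; the real loop body is elided in the post."""

    def __init__(self):
        super().__init__()
        self.scale1 = nn.Parameter(torch.tensor([1.0]))
        self.scale2 = nn.Parameter(torch.tensor([1.0]))

    def forward(self, x):
        # Several separate aten ops, so inlining has something to expand.
        return torch.round(x * self.scale1) * self.scale2


model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), TinyLTQ()).eval()
example = torch.randn(1, 3, 8, 8)

with torch.no_grad():
    traced = torch.jit.trace(model, example)

graph = traced.graph
# Before inlining: one prim::CallMethod node per submodule call.
kinds_before = [node.kind() for node in graph.nodes()]
print(kinds_before.count("prim::CallMethod"))

# Same private pass as in the question; it rewrites the graph in place.
torch._C._jit_pass_inline(graph)
# After inlining: TinyLTQ's individual ops (aten::mul, aten::round, ...) appear.
kinds_after = [node.kind() for node in graph.nodes()]
print(kinds_after.count("prim::CallMethod"), "aten::round" in kinds_after)
```

Read this way, `nn.Conv2d` only looks monolithic after inlining because its `forward` (with the default `padding_mode`) reduces to a single aten convolution call, whereas `LTQ.forward` issues many separate aten ops.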
For my use case, I would like the entire `LTQ(nn.Module)` to be treated as a single monolithic operator, like `nn.Conv2d`, when I run `torch._C._jit_pass_inline()` on the output of `torch.jit.trace`. Currently, after `_jit_pass_inline`, every line inside the `forward` function shows up in the graph, with the operators on each line mapped to the corresponding aten operators. How can I make the trace treat this class as a single operator? @ptrblck @smth @albanD, your help would be greatly appreciated.
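One possible approach, sketched under assumptions rather than as a definitive answer: register the quantizer as a custom operator through `torch.library` (available in recent PyTorch versions), so that `forward` makes a single dispatcher call and the tracer records it as one node. The namespace `ltq_demo`, the op schema and the `LTQOp` module are hypothetical; the kernel body is a stand-in because the real loop is elided; and no autograd formula is registered, so this only covers tracing/inference. Training through the op would additionally need an autograd registration that this sketch does not cover.

```python
import torch
import torch.library
import torch.nn as nn

# Hypothetical namespace and op name. Keep the Library object alive:
# its registrations are removed when it is garbage-collected.
_lib = torch.library.Library("ltq_demo", "DEF")
_lib.define("ltq(Tensor x, Tensor scale1, Tensor scale2) -> Tensor")


def _ltq_kernel(x, scale1, scale2):
    # Stand-in math; the real LTQ loop body is elided in the post.
    return torch.round(x * scale1) * scale2


# CPU kernel only, no autograd formula: enough for tracing/inference.
_lib.impl("ltq", _ltq_kernel, "CPU")


class LTQOp(nn.Module):
    """Hypothetical LTQ variant whose forward is a single custom-op call."""

    def __init__(self):
        super().__init__()
        self.scale1 = nn.Parameter(torch.tensor([1.0]))
        self.scale2 = nn.Parameter(torch.tensor([1.0]))

    def forward(self, x):
        # One dispatcher call, so the tracer records one ltq_demo::ltq node.
        return torch.ops.ltq_demo.ltq(x, self.scale1, self.scale2)


model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), LTQOp()).eval()
example = torch.randn(1, 3, 8, 8)

with torch.no_grad():
    traced = torch.jit.trace(model, example)

graph = traced.graph
torch._C._jit_pass_inline(graph)  # same pass as in the question
kinds = [node.kind() for node in graph.nodes()]
# The stand-in's inner aten::round should be hidden inside the custom node.
print("ltq_demo::ltq" in kinds, "aten::round" in kinds)
```

The ops executed inside the Python kernel are not recorded, because the tracer records the custom op itself and then calls the kernel with tracing disabled. The trade-off is that anything consuming the graph later (for example, loading it outside Python) also needs the `ltq_demo::ltq` operator to be registered there.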