Treat a custom nn.Module class as a monolithic operator

Hi, I have the following class extending nn.Module for performing quantization:

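```python
import torch
import torch.nn as nn


class LTQ(nn.Module):
    def __init__(self, num_bits):
        super(LTQ, self).__init__()
        init_range = 2.0
        self.n_val = 2**num_bits - 1
        self.interval = init_range / self.n_val
        # Learnable parameters: start point, per-interval lengths `a`, two scales
        self.start = nn.Parameter(torch.Tensor([0.0]), requires_grad=True)
        self.a = nn.Parameter(
            torch.Tensor([self.interval] * self.n_val), requires_grad=True
        )
        self.scale1 = nn.Parameter(torch.Tensor([1.0]), requires_grad=True)
        self.scale2 = nn.Parameter(torch.Tensor([1.0]), requires_grad=True)

        # Fixed constants stored as non-trainable parameters
        self.two = nn.Parameter(torch.Tensor([2.0]), requires_grad=False)
        # The attribute name was lost in the post; "zero" is an assumption
        self.zero = nn.Parameter(torch.Tensor([0.0]), requires_grad=False)
        self.eps = nn.Parameter(torch.Tensor([1e-3]), requires_grad=False)

    def forward(self, x):
        x = x * self.scale1

        x_forward = x
        x_backward = x
        step_right = + 0.0  # an operand before "+" seems to be missing in the post

        # Clamp the interval lengths from below at eps
        a_pos = torch.where(self.a > self.eps, self.a, self.eps)

        for i in range(self.n_val):
            step_right += self.interval
            if i == 0:
                ...  # <some code> (elided in the original post)
                ...  # <some code>
        # ... (rest of forward elided in the original post)
```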

For my use case, I would like this entire LTQ module to be treated as a single monolithic operator, like nn.Conv2d, when I run torch._C._jit_pass_inline() on the graph produced by torch.jit.trace. Instead, after inlining, the graph contains every operation executed inside the forward function, each mapped to its corresponding aten operator. How can I make the trace treat this class as a single operator? @ptrblck @smth @albanD your help would be greatly appreciated.
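The behavior described above can be reproduced with a small model. This is a sketch under assumptions: `QuantStub` is a hypothetical stand-in for `LTQ` (a few elementwise ops, not the real quantizer), and `Net` is a hypothetical parent model. It compares the node kinds of the traced top-level graph before and after `torch._C._jit_pass_inline`, which modifies the graph in place.

```python
import torch
import torch.nn as nn


class QuantStub(nn.Module):
    # Hypothetical stand-in for LTQ: a few elementwise ops in forward
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return torch.relu(x * self.scale) + 1.0


class Net(nn.Module):
    # Hypothetical parent model: a conv followed by the quantizer stand-in
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.quant = QuantStub()

    def forward(self, x):
        return self.quant(self.conv(x))


traced = torch.jit.trace(Net().eval(), torch.randn(1, 3, 8, 8))
graph = traced.graph

# Before inlining: each submodule call is a single prim::CallMethod node
before = {node.kind() for node in graph.nodes()}

# _jit_pass_inline mutates the graph in place, expanding every CallMethod
torch._C._jit_pass_inline(graph)
after = {node.kind() for node in graph.nodes()}

print(sorted(before))
print(sorted(after))
```

In this sketch, the graph inspected before `_jit_pass_inline` shows each submodule (including the quantizer stand-in) as one `prim::CallMethod` node, while the inlined graph exposes the individual aten ops such as `aten::mul` and `aten::relu`.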

I would recommend not tagging specific users, as it might discourage others from posting a valid answer.

If I understand your question correctly, you are concerned that tracing shows all internal operations instead of a single “wrapper” operation? If so, unfortunately I don’t know the answer.