I'm quantizing nn.Modules (with my own API, not torch.quantization). I want to train black-box user modules and convert them afterwards by simply replacing some of their submodules. So far so good, but I don't know how to properly access the control flow, which I need to know for architectures that are not purely sequential. For example, the forward pass may contain residual add connections, like so:
def forward(self, x):
    x = F.relu(self.fc1(x))
    x = self.sequentialstack(x) + x
    x = self.softmax(x)
    return x
which I'd like to replace with something like
def forward(self, x):
    x = F.relu(self.fc1_quant(x))
    x_residual_cache = x
    x = self.sequentialstack_quant(x)
    x = self.quantized_adder(x, x_residual_cache)
    x = self.softmax_quant(x)
    return x
What is the recommended way to do this? I have read the Intro to TorchScript tutorial, but I don't immediately see what the recommended way to inspect a module's internal control flow is. So before I throw regexes at scripted_model.forward.code, I thought I'd ask what the right way to do this is.
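One direction that might avoid regexes, offered as a minimal sketch rather than a confirmed recommendation: torch.fx can symbolically trace a module into a graph whose nodes can be walked to locate the residual add. The class `ResidualNet`, the helper `find_residual_adds`, and the layer sizes below are hypothetical stand-ins for a user module, not part of any real API.

```python
import operator

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import fx


class ResidualNet(nn.Module):
    """Hypothetical stand-in for a black-box user module with a residual add."""

    def __init__(self, dim: int = 8):  # dim=8 is an arbitrary assumption
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.sequentialstack = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.sequentialstack(x) + x
        x = self.softmax(x)
        return x


def find_residual_adds(module: nn.Module):
    """Trace `module` and return (graph_module, nodes that perform an add).

    Hypothetical helper: it only matches `a + b` and `torch.add(a, b)`.
    """
    gm = fx.symbolic_trace(module)
    adds = [
        node
        for node in gm.graph.nodes
        if node.op == "call_function"
        and node.target in (operator.add, torch.add)
    ]
    return gm, adds


gm, adds = find_residual_adds(ResidualNet())
for node in adds:
    # node.args holds the two graph nodes feeding the add
    print(node.name, "<-", [str(arg) for arg in node.args])
```

Each node in `adds` exposes its two inputs through `node.args`, which is the information needed to route both branches into something like `quantized_adder`. One caveat: symbolic tracing does not handle data-dependent Python control flow, so this would only cover modules whose forward pass is a static graph.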