Why does module fusion replace fused modules with nn.Identity?

My question is simple: why replace the fused modules with nn.Identity rather than simply delete them in the function fuse_known_modules (see code)?

The profiler shows that this introduces a non-negligible delay and largely wipes out the performance improvement gained by fusion. As you can see in the trace, the overhead of the two nn.Identity modules is almost equal to the latency of aten::batch_norm in the original fp32 model, which cancels out the improvement gained by fusing conv1d with batch_norm.
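For reference, here is a minimal sketch of how such a measurement can be reproduced (the toy conv1d+bn model, shapes, and iteration count below are made up for illustration; my actual trace comes from my own model):

```python
import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(16, 16, 3)
        self.bn = nn.BatchNorm1d(16)

    def forward(self, x):
        return self.bn(self.conv(x))

# fuse_modules (which uses fuse_known_modules as its default fuser function)
# leaves an nn.Identity where self.bn used to be.
fused = fuse_modules(Net().eval(), [["conv", "bn"]])
x = torch.randn(1, 16, 4096)

with torch.profiler.profile() as prof:
    for _ in range(100):
        fused(x)

prof.export_chrome_trace("fused_trace.json")  # open in chrome://tracing to inspect the spans
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
```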

  1. usually people either JIT script the model or export it to get a final, maximally performant version. This removes the nn.Identity (see the sketch right after this list).
  2. if you delete random modules your model will error.
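For point 1, a minimal sketch of what that looks like (the toy module below is made up; the nn.Identity stands in for what fusion leaves behind):

```python
import torch
import torch.nn as nn

class Fused(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(16, 16, 3)  # stands in for the fused conv+bn module
        self.bn = nn.Identity()           # what fusion leaves behind

    def forward(self, x):
        return self.bn(self.conv(x))

# Scripting and then freezing inlines the submodule calls, so the Identity should
# no longer appear as a separate call in the optimized graph.
frozen = torch.jit.freeze(torch.jit.script(Fused().eval()))
print(frozen.code)
```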

if you have a model like

self.lin = nn.Linear(...)
self.bn = nn.BatchNorm1d(...)

def forward(self, x):
    y = self.bn(self.lin(x))
    return y

then deleting self.bn isn’t the same as leaving an identity in its place: the forward still calls self.bn, so the model no longer knows where to send the activation after self.lin and errors out.
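To make that concrete, here is a minimal runnable sketch (layer sizes are made up): replacing self.bn with nn.Identity keeps the forward intact, while deleting the attribute breaks it.

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)
        self.bn = nn.BatchNorm1d(8)

    def forward(self, x):
        y = self.bn(self.lin(x))
        return y

m = Model().eval()
x = torch.randn(4, 8)

m.bn = nn.Identity()  # what fusion does: self.bn becomes a pass-through, forward still works
m(x)

del m.bn              # simply removing the module instead
try:
    m(x)
except AttributeError as e:
    print(e)          # 'Model' object has no attribute 'bn'
```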

  3. it makes debugging a lot easier. We normally debug quantization by comparing the activations at each module output for both the quantized and non-quantized model. If you do a fusion like

location 1: A
location 2: B

fused into

location 1: A+B
location 2: identity

then you can compare the output of location 2 for both models and get an apples-to-apples comparison. If you compare location 1, in the top case you’d get the output of A, while in the bottom case you’d get the output of A+B, i.e. effectively B’s output. If you delete location 2, there is no longer any point that has both the same location and the same operations in the two models.
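Here is a sketch of that comparison (the toy conv1d+bn model and shapes are made up; the point is that the fused model still has a module at "location 2" to hook into):

```python
import copy
import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(16, 16, 3)  # "location 1" (A)
        self.bn = nn.BatchNorm1d(16)      # "location 2" (B)

    def forward(self, x):
        return self.bn(self.conv(x))

float_model = Net().eval()
fused_model = fuse_modules(copy.deepcopy(float_model), [["conv", "bn"]])  # bn -> nn.Identity

# Capture the activation at "location 2" in both models with forward hooks.
captured = {}
def save(name):
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook

float_model.bn.register_forward_hook(save("float"))
fused_model.bn.register_forward_hook(save("fused"))  # still a valid hook point: bn is an Identity

x = torch.randn(2, 16, 64)
float_model(x)
fused_model(x)

# Same location, same math on both sides, so the activations should agree up to the
# small numerical error introduced by folding bn into conv.
print((captured["float"] - captured["fused"]).abs().max())
```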
