Fusing modules results in different outputs

Hello,

I recently wrote a U-Net-like model that takes two inputs (an input image and a mask) and returns one output. Now I am trying to fuse its modules before quantization. The model architecture is shown below.

PartialConvUNet(
  (encoder_1): PartialConvLayer(
    (input_conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (mask_conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (activation): ReLU()
  )
  (encoder_2): PartialConvLayer(
    (input_conv): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
    (mask_conv): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
    (batch_normalization): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (encoder_3): PartialConvLayer(
    (input_conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (batch_normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (encoder_4): PartialConvLayer(
    (input_conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (batch_normalization): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (encoder_5): PartialConvLayer(
    (input_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (batch_normalization): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (encoder_6): PartialConvLayer(
    (input_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (batch_normalization): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (encoder_7): PartialConvLayer(
    (input_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (batch_normalization): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (decoder_5): PartialConvLayer(
    (input_conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (batch_normalization): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (decoder_6): PartialConvLayer(
    (input_conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (batch_normalization): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (decoder_7): PartialConvLayer(
    (input_conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (batch_normalization): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (decoder_4): PartialConvLayer(
    (input_conv): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (batch_normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (decoder_3): PartialConvLayer(
    (input_conv): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (batch_normalization): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (decoder_2): PartialConvLayer(
    (input_conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (batch_normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (decoder_1): PartialConvLayer(
    (input_conv): Conv2d(67, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (mask_conv): Conv2d(67, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
)

Problem:
After fusing mask_conv with the operations that follow it (batch normalization and/or activation), the model returns a different (wrong) output compared to the model before fusing. What could be the reason?

    # Fuse mask_conv with its BatchNorm / ReLU siblings inside every
    # PartialConvLayer of `model`, in place, ahead of quantization.
    #
    # NOTE(review): fuse_modules folds the listed submodules into the first one
    # and swaps the others for Identity, so the result is only equivalent when
    # forward() applies exactly those modules back-to-back in that order.
    # PartialConvLayer.forward() is not visible here -- if batch_normalization /
    # activation are applied to the input_conv branch (or to a value derived
    # from it) rather than directly to mask_conv's output, these fusions move
    # them onto the mask branch and change the model's output.
    # TODO confirm against PartialConvLayer.forward().
    for name, module in model.named_modules():
        # isinstance rather than an exact type() comparison, so subclasses of
        # PartialConvLayer are picked up as well.
        if not isinstance(module, PartialConvLayer):
            continue

        if "encoder_1" in name:
            # encoder_1 has no batch_normalization: fuse conv + activation only.
            torch.quantization.fuse_modules(
                module, [['mask_conv', 'activation']], inplace=True
            )
        elif "enc" in name:
            # encoder_2 ~ encoder_7: conv + batch norm + activation.
            torch.quantization.fuse_modules(
                module,
                [['mask_conv', 'batch_normalization', 'activation']],
                inplace=True,
            )
        elif "decoder_1" not in name:
            # decoder_2 ~ decoder_7: conv + batch norm; the activation is left
            # out of the fusion list. decoder_1 is skipped entirely.
            torch.quantization.fuse_modules(
                module, [['mask_conv', 'batch_normalization']], inplace=True
            )

The model after fusion:

PartialConvUNet(
  (encoder_1): PartialConvLayer(
    (input_conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (mask_conv): ConvReLU2d(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): ReLU()
    )
    (activation): Identity()
  )
  (encoder_2): PartialConvLayer(
    (input_conv): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
    (mask_conv): ConvReLU2d(
      (0): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
      (1): ReLU()
    )
    (batch_normalization): Identity()
    (activation): Identity()
  )
  (encoder_3): PartialConvLayer(
    (input_conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (mask_conv): ConvReLU2d(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (batch_normalization): Identity()
    (activation): Identity()
  )
  (encoder_4): PartialConvLayer(
    (input_conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (mask_conv): ConvReLU2d(
      (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (batch_normalization): Identity()
    (activation): Identity()
  )
  (encoder_5): PartialConvLayer(
    (input_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (mask_conv): ConvReLU2d(
      (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (batch_normalization): Identity()
    (activation): Identity()
  )
  (encoder_6): PartialConvLayer(
    (input_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (mask_conv): ConvReLU2d(
      (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (batch_normalization): Identity()
    (activation): Identity()
  )
  (encoder_7): PartialConvLayer(
    (input_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (mask_conv): ConvReLU2d(
      (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (batch_normalization): Identity()
    (activation): Identity()
  )
  (decoder_7): PartialConvLayer(
    (input_conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batch_normalization): Identity()
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (decoder_6): PartialConvLayer(
    (input_conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batch_normalization): Identity()
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (decoder_5): PartialConvLayer(
    (input_conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batch_normalization): Identity()
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (decoder_4): PartialConvLayer(
    (input_conv): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batch_normalization): Identity()
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (decoder_3): PartialConvLayer(
    (input_conv): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batch_normalization): Identity()
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (decoder_2): PartialConvLayer(
    (input_conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (mask_conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batch_normalization): Identity()
    (activation): LeakyReLU(negative_slope=0.2)
  )
  (decoder_1): PartialConvLayer(
    (input_conv): Conv2d(67, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (mask_conv): Conv2d(67, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
)