RuntimeError: No function is registered for schema aten::thnn_conv2d_forward

1. Fuse conv2d and batchnorm2d
2. Specify the configuration of the quantization method
3. Call torch.quantization.prepare()
4. Calibrate the model by running inference against a calibration dataset
But when it comes to step 4, I get the error above.

I don't know why the code runs with torch.nn.modules.conv instead of the quantized modules.
Here is the code:

import torch

model.to('cpu')
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
evaluate(model, data_loader_test, 10)  # calibration pass
torch.quantization.convert(model, inplace=True)
top1, top5, loss = evaluate(model, data_loader_test)  # <-- WRONG

Did you add QuantStub and DeQuantStub correctly in the original model? Please follow the "(beta) Static Quantization with Eager Mode in PyTorch" tutorial in the PyTorch Tutorials documentation to do the quantization.
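
For reference, a minimal sketch of the eager-mode flow with the stubs in place might look like the following. TinyNet, the input shape and the random calibration data are made up for illustration and are not from the posts above; the backend selection assumes the build supports fbgemm or qnnpack.

```python
import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub


class TinyNet(nn.Module):  # hypothetical model, not the poster's
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float -> quantized at the input
        self.conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()  # quantized -> float at the output

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        return self.dequant(x)


# Pick a backend this build supports (assumption: fbgemm on x86, qnnpack on ARM).
engines = torch.backends.quantized.supported_engines
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
torch.backends.quantized.engine = backend

model = TinyNet().to("cpu")
model.eval()  # set eval mode before fusing for post-training quantization
torch.quantization.fuse_modules(model, [["conv", "bn", "relu"]], inplace=True)
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.quantization.prepare(model, inplace=True)

# Calibration: random tensors stand in for a real calibration dataset.
with torch.no_grad():
    for _ in range(4):
        model(torch.randn(2, 3, 16, 16))

torch.quantization.convert(model, inplace=True)
with torch.no_grad():
    out = model(torch.randn(2, 3, 16, 16))

conv_type = type(model.conv)
print(conv_type)  # expected to live under a "quantized" module path
```

If the printed type is still a float torch.nn class after convert, one of the earlier steps (stubs, eval mode, fusion, qconfig) was probably skipped.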

Hello! I'm following the tutorial to quantize my model, and I got the same error.

I tried a simple Conv + BN + ReLU model like this:

import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

class ConvBNReLU(nn.Sequential):
	def __init__(self, in_channel, out_channel, kernel_size, stride):
		padding = (kernel_size - 1) // 2
		super(ConvBNReLU, self).__init__(
			nn.Conv2d(
				in_channel,
				out_channel,
				kernel_size,
				stride,
				padding,
				bias=False
			),
			nn.BatchNorm2d(out_channel),
			nn.ReLU(inplace=False)
		)

class CNN(nn.Module):
	def __init__(self):
		super(CNN, self).__init__()
		# input size 3 * 32 * 32
		self.conv1 = ConvBNReLU(3, 16, 3, 1)

		self.conv2 = ConvBNReLU(16, 32, 3, 1)

		self.quant = QuantStub()
		self.dequant = DeQuantStub()

		self.out = nn.Linear(32 * 32 * 32, 10)

	def forward(self, x):
		x = self.quant(x)
		x = self.conv1(x)
		x = self.conv2(x)
		x = x.contiguous()
		x = x.view(-1, 32 * 32 * 32)
		x = self.out(x)
		x = self.dequant(x)
		return x

	def fuse_model(self):
		# Fuse Conv2d + BatchNorm2d + ReLU inside every ConvBNReLU block
		for m in self.modules():
			if isinstance(m, ConvBNReLU):
				torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)

Model after adding observers:

ConvBNReLU(
  (0): ConvBnReLU2d(
    (0): Conv2d(
      3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
      (observer): MinMaxObserver(min_val=None, max_val=None)
    )
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(
      (observer): MinMaxObserver(min_val=None, max_val=None)
    )
  )
  (1): Identity()
  (2): Identity()
)

Model after convert:

ConvBNReLU(
  (0): ConvBnReLU2d(
    (0): Conv2d(
      3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
      (observer): MinMaxObserver(min_val=-3.1731998920440674, max_val=3.2843430042266846)
    )
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(
      (observer): MinMaxObserver(min_val=0.0, max_val=13.862381935119629)
    )
  )
  (1): Identity()
  (2): Identity()
)

The rest just follows the tutorial.
I wonder why the layers were not converted to quantized layers.

We can't quantize batchnorm on its own; you need to fuse the batchnorm layer first. Please follow the tutorial (https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html#model-architecture) to fuse conv, batchnorm and relu. You can also take a look at our test to see how to use the fusion API: https://github.com/pytorch/pytorch/blob/master/test/test_quantization.py#L949

Hello Jerry, thanks for the reply!
I have checked the test that does fusion, but I still have some questions.
Here is my model before fusion:

CNN(
  (conv1): ConvBNReLU(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv2): ConvBNReLU(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (out): Linear(in_features=32768, out_features=10, bias=True)
)

And here it is after fusion:

CNN(
  (conv1): ConvBNReLU(
    (0): ConvBnReLU2d(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Identity()
    (2): Identity()
  )
  (conv2): ConvBNReLU(
    (0): ConvBnReLU2d(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Identity()
    (2): Identity()
  )
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (out): Linear(in_features=32768, out_features=10, bias=True)
)

It seems that conv, BN and ReLU have been converted to ConvBnReLU2d. Did I miss anything?
I'm doing post-training quantization. Is it the case that only conv + ReLU can be quantized?

Also, about the model in the section "Define dataset and data loaders" (https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html#define-dataset-and-data-loaders):
I am confused why the model before fusion is conv, BN, ReLU, but after fusion it is only ConvReLU2d.
Is BN removed during fusion?

The fusion result looks correct.
Yes, fusion means we fold the parameters of batchnorm into the conv, since batchnorm is just a linear transformation of its input at inference time. See https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/fusion.py#L7 for the code that fuses conv and bn.
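
To make the "linear transformation" point concrete, here is a rough sketch that folds a BatchNorm2d into the preceding Conv2d by hand and checks that the outputs match. It is not the library's fusion code; the layer sizes and the random BN statistics are assumptions chosen only so the check is non-trivial.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 4, 3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(4).eval()

# Give BN non-trivial statistics and affine parameters (illustrative values).
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

fused = nn.Conv2d(3, 4, 3, padding=1, bias=True).eval()
with torch.no_grad():
    # scale = gamma / sqrt(running_var + eps), one value per output channel
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    # The original conv has no bias, so the folded bias is beta - mean * scale.
    fused.bias.copy_(bn.bias - bn.running_mean * scale)

x = torch.randn(2, 3, 8, 8)
with torch.no_grad():
    max_diff = (bn(conv(x)) - fused(x)).abs().max().item()
print(max_diff)  # should be on the order of float32 rounding error
```

This only holds at inference time, when BN uses its running statistics; in training mode BN depends on the current batch and cannot be folded this way.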

I found that I forgot to switch my model to evaluation mode.
It works now!
Thanks for the help!
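
For anyone hitting the same thing, a small sketch of the fix: call eval() before fuse_modules and check what the fused block turns into. The standalone nn.Sequential below mirrors the ConvBNReLU block from earlier in the thread but is only an illustration.

```python
import torch
import torch.nn as nn

block = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=False),
)
block.eval()  # the step that was missing

# inplace defaults to False, so a fused copy is returned
fused = torch.quantization.fuse_modules(block, ["0", "1", "2"])

fused_name = type(fused[0]).__name__
has_bn = any(isinstance(m, nn.BatchNorm2d) for m in fused.modules())
print(fused_name, has_bn)  # prints "ConvReLU2d False"
```

In eval mode BN is folded into the conv weights, so no BatchNorm2d is left for convert to trip over. In training mode the older releases discussed here produce the ConvBnReLU2d seen in the printouts above instead, and as far as I know newer releases may refuse to fuse altogether.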

Is it possible to fuse nn.BatchNorm1d(512) with Linear and ReLU?

I am getting the following error:
NotImplementedError: Cannot fuse modules: (<class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.batchnorm.BatchNorm1d'>, <class 'torch.nn.modules.activation.ReLU'>)

No, we only have fusion between conv and batchnorm right now.
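
As a possible workaround, the same folding arithmetic can be applied by hand to Linear + BatchNorm1d in eval mode. The sketch below is not a built-in API: fold_bn1d_into_linear is a hypothetical helper, and the layer sizes and random running statistics are assumptions for illustration.

```python
import torch
import torch.nn as nn


def fold_bn1d_into_linear(linear, bn):
    """Hypothetical helper: build an nn.Linear equal to bn(linear(x)) in eval mode."""
    assert not bn.training, "folding uses running statistics, so call eval() first"
    folded = nn.Linear(linear.in_features, linear.out_features, bias=True)
    with torch.no_grad():
        # scale = gamma / sqrt(running_var + eps), one value per output feature
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        bias = linear.bias if linear.bias is not None else torch.zeros(linear.out_features)
        folded.weight.copy_(linear.weight * scale.unsqueeze(1))
        folded.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return folded.eval()


torch.manual_seed(0)
linear = nn.Linear(32, 512).eval()
bn = nn.BatchNorm1d(512).eval()
with torch.no_grad():  # illustrative, non-trivial running statistics
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)

folded = fold_bn1d_into_linear(linear, bn)
x = torch.randn(4, 32)
with torch.no_grad():
    max_diff = (bn(linear(x)) - folded(x)).abs().max().item()
print(max_diff)  # should be on the order of float32 rounding error
```

The folded Linear can then replace the Linear + BatchNorm1d pair in the model, leaving only Linear followed by ReLU. Later PyTorch releases may also support more fusion patterns out of the box, so it is worth checking the fuse_modules documentation for the installed version first.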