my model as following:
from torch.quantization import QuantStub, DeQuantStub
class dehaze_net(nn.Module):
def __init__(self):
super(dehaze_net, self).__init__()
self.relu = nn.ReLU(inplace=False)
self.e_conv1 = nn.Conv2d(3,3,1,1,0,bias=True)
self.e_conv2 = nn.Conv2d(3,3,3,1,1,bias=True)
self.e_conv3 = nn.Conv2d(6,3,5,1,2,bias=True)
self.e_conv4 = nn.Conv2d(6,3,7,1,3,bias=True)
self.e_conv5 = nn.Conv2d(12,3,3,1,1,bias=True)
self.skip_add = nn.quantized.FloatFunctional()
self.quant = QuantStub()
self.dequant = DeQuantStub()
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def forward(self, x):
x = self.quant(x)
source = []
source.append(x)
x1 = self.relu(self.e_conv1(x))
x2 = self.relu(self.e_conv2(x1))
concat1 =self.skip_add.cat((x1,x2), 1)
x3 = self.relu(self.e_conv3(concat1))
concat2 = self.skip_add.cat((x2, x3), 1)
x4 = self.relu(self.e_conv4(concat2))
concat3 = self.skip_add.cat((x1,x2,x3,x4),1)
x5 = self.relu(self.e_conv5(concat3))
#clean_image = self.relu(self.skip_add.add((x5 * x) - x5, 1))
clean_image = self.skip_add.add_relu(self.skip_add.add(self.skip_add.mul(x5,x),-x5), 1)
clean_image = self.dequant(clean_image)
return clean_image
i load the pretrained model ,and predict the image ,it is right .as follows
def load_model(model_file):
model = dehaze_net()
state_dict = torch.load(model_file)
model.load_state_dict(state_dict)
model.to('cpu')
return model
myModel = load_model(float_model_file).to('cpu')
myModel.eval()
myModel.qconfig = torch.quantization.default_qconfig
print(myModel.qconfig)
torch.quantization.prepare(myModel, inplace=True)
validation(myModel, data_loader_test)
but when i excute the following codes ,it give some errors:
torch.quantization.convert(myModel, inplace=True)
validation(myModel, data_loader_test)
errors:
<ipython-input-44-674cc7c0f6b1> in forward(self, x)
47
48 #clean_image = self.relu(self.skip_add.add((x5 * x) - x5, 1))
---> 49 clean_image = self.skip_add.add_relu(self.skip_add.add(self.skip_add.mul(x5,x),-x5), 1)
50
51 clean_image = self.dequant(clean_image)
RuntimeError: No function is registered for schema aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor on tensor type QuantizedCPUTensorId; available functions are CPUTensorId, CUDATensorId, MkldnnCPUTensorId, SparseCPUTensorId, SparseCUDATensorId, VariableTensorId
thanks