# How do I quantize my model? Should I use quant/dequant stubs?

class ISNetDIS(nn.Module):
    """Six-stage encoder / five-stage decoder network built from RSU blocks.

    The encoder (``stage1`` .. ``stage6``) halves the resolution between
    stages with max-pooling; the decoder (``stage5d`` .. ``stage1d``)
    concatenates each upsampled decoder feature with the encoder feature of
    the same level. ``forward`` returns only the sigmoid of the first side
    output.

    Args:
        in_ch: Number of channels of the input image.
        out_ch: Number of channels produced by every side-output head.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        # Input stem. ``forward`` calls ``self.conv_in``, so it has to exist;
        # it must emit 64 channels because ``stage1`` consumes 64.
        # NOTE(review): kernel 3 / stride 2 / padding 1 follows the published
        # IS-Net layout -- confirm against the checkpoint you intend to load.
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        # Defined for parity with the original layout; ``forward`` does not
        # apply it.
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # ---- encoder ----
        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # ---- decoder ----
        # Input widths are doubled because each decoder stage receives the
        # channel-wise concatenation of two feature maps.
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # ---- side-output heads ----
        # ``forward`` calls ``side1`` .. ``side6``. Input widths mirror the
        # last argument of the stage each head reads from (assumes that
        # argument is the stage's output channel count, which is consistent
        # with how the stage widths chain above -- TODO confirm).
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
        """Return ``muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)``."""
        return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)

    def compute_loss(self, preds, targets):
        """Return ``muti_loss_fusion(preds, targets)``."""
        return muti_loss_fusion(preds, targets)

    def forward(self, x):
        """Run the network and return the sigmoid of the first side output.

        Args:
            x: Input tensor; also used as the size reference passed to
                ``_upsample_like`` for every side output.

        Returns:
            ``torch.sigmoid`` of ``side1(hx1d)`` after ``_upsample_like(., x)``.
        """
        hxin = self.conv_in(x)

        # -------------------- encoder --------------------
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # -------------------- side outputs --------------------
        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1, x)

        # NOTE(review): d2..d6 are evaluated but never returned. They are
        # kept so the set of executed sub-modules matches the original code;
        # drop them (and the matching heads) if only d1 is ever needed.
        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, x)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, x)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, x)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, x)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, x)

        # torch.sigmoid is numerically the same as F.sigmoid and avoids the
        # deprecation warning that older PyTorch releases emit for the latter.
        return torch.sigmoid(d1)