Dear all,
I have a problem in loading the pre-trained configuration of a small variation of a ReNet34 model. I want to use the PyTorch quantisation tool hence I need to implement my own model’s version adding the Stubs. Below, you find my ResNet implementation:
class BasicBlock(nn.Module):
expansion: int = 1def __init__( self, inplanes: int, planes: int, stride: int = 1, groups: int = 1, useBottleneck: bool = False, downsample: bool = False, dilation: int = 1, ) -> None: super().__init__() if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = nn.Conv2d(inplanes,planes,kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample if useBottleneck: filters = [64, 256, 512, 1024, 2048] else: filters = [64, 64, 128, 256, 512] if self.downsample: self.shortcut = nn.Sequential( nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias = False), nn.BatchNorm2d(planes) ) self.stride = stride self.skip_add = nn.quantized.FloatFunctional() def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample: identity = self.shortcut(x) out = self.skip_add.add(out, identity) out = self.relu(out) return out
class ResNet(Resnet):
def __init__(self, in_channels = 3, resblock = ResBottleneckBlock, repeat = [3, 4, 6, 3], useBottleneck=True, outputs=2, seed = 12345, pretrained = False): super().__init__(in_channels, resblock, repeat, useBottleneck) if pretrained: checkpoint = 'https://download.pytorch.org/models/resnet50-11ad3fa6.pth' ckpt = torch.hub.load_state_dict_from_url(checkpoint, progress=False) new = list(ckpt.items()) my_model_kvpair = self.state_dict() count=0 for key,value in my_model_kvpair.items(): layer_name, weights=new[count] if value.shape == weights.shape: my_model_kvpair[key]= weights count+=1 self.load_state_dict(my_model_kvpair) #for si_cura self.fc = nn.Sequential(nn.Linear(self.fc.in_features,500), nn.ReLU(), nn.Dropout(), nn.Linear(500,outputs))
class ResNet34(Resnet):
def __init__(self, in_channels = 3, resblock = BasicBlock, repeat = [3, 4, 6, 3], useBottleneck=False, outputs=2, seed = 12345, pretrained = False): super().__init__(in_channels, resblock, repeat, useBottleneck) if pretrained: checkpoint = 'https://download.pytorch.org/models/resnet34-b627a593.pth' ckpt = torch.hub.load_state_dict_from_url(checkpoint, progress=True) new = list(ckpt.items()) my_model_kvpair = self.state_dict() count=0 for key,value in my_model_kvpair.items(): layer_name, weights=new[count] if value.shape == weights.shape: print("Ciao") my_model_kvpair[key]= weights count+=1 self.load_state_dict(my_model_kvpair) #for si_cura self.fc = nn.Sequential(nn.Linear(self.fc.in_features,500), nn.ReLU(), nn.Dropout(), nn.Linear(500,outputs))
I found out that my model does not have the parameter num_batches_track. Anyway, I tried to look into the Pytorch source implementation and my BasicBlock coincides with theirs. I’m using the last version of PyTorch so I am pretty sure that there are no version problems. Do you have any suggestion?
Thanks in advance,
Max