Hi, I’m trying to make a CNN model that use custom filters/weights. I started by using a pretrained model and changed it according to my need (figure below to better explanation of the idea). The goal is to have a 3 channels image then filter the input with all filters in each layer. I want to ask if the weights implementation is done right. Here is the code I have done, I know it might not be clean sorry for that.
model = models.alexnet(pretrained= True)
# replace avgpool:
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
model.avgpool= Identity()
# remove some layer:
model.features[8] = nn.MaxPool2d(3, stride=2, padding=0, dilation=1, ceil_mode=False)
model.features = nn.Sequential(*[model.features[i] for i in range(9)])
model.classifier = nn.Sequential(nn.Dropout(p=0.5),
nn.Linear(3,3),
nn.ReLU(inplace=True),
nn.Linear(3,3))
############ 1st hidden layer:
#replace Conv2d:
model.features[0]= nn.Conv2d(3, 1, 3, stride=1, groups=1)
#change weights:
model.features[0].weight = nn.Parameter(data= torch.tensor([
[[1, 2, 1],
[0, 0, 0],
[-1, -2, -1]],
[[-1, 0, 1],
[-2, 0, 2],
[-1, 0, 1]],
[[2, 1, 0],
[1, 0, -1],
[0, -1, -2]],
[[0, -1, -2],
[1, 0, -1],
[2, 1, 0]]], dtype= torch.float32), requires_grad= True)
###########2nd hidden layer:
#replace 2nd Conv2d:
model.features[3]= nn.Conv2d(3, 1, 3, stride=1, groups=1)
#add weights:
model.features[3].weight = nn.Parameter(data= torch.tensor([
[[0.081, 0.17789, 0.081],
[0, 0, 0],
[-0.081, -0.17789, -0.081]],
[[0.081, 0, -0.081],
[0.17789, 0, -0.17789],
[0.081, 0, -0.081]],
[[0.17789, 0.081 , 0],
[0.081, 0, -0.081],
[0, -0.081, -0.17789]],
[[0, -0.081, -0.17789],
[0.081, 0, -0.081],
[0.17789, 0.081, 0]]], dtype= torch.float32), requires_grad= False)
#######3rd Hidden layer:
#replace 3rd Conv2d:
model.features[6]= nn.Conv2d(3, 1, 3, stride=1, groups=1)
#add weights:
model.features[6].weight= nn.Parameter(data= torch.tensor([
[[0.0455, -0.1789, 0.0455],
[0.1, -0.388, 0.1],
[0.0455, -0.1789, 0.0455]],
[[0.0455, 0.1, 0.0455],
[-0.1789, -0.388, -0.1789],
[0.0455, 0.1, 0.0455]],
[[0.1272, 0, -0.1272],
[0, 0, 0],
[-0.1272, 0, 0.1272]],
[[-0.1272, 0, 0.1272],
[0,0,0],
[0.1272, 0, -0.1272]]], dtype=torch.float32), requires_grad= False)