Hi @ptrblck, I tried to follow the method you suggested for my network.
I want to extract all four marked features in a single pass, but I am unsure whether they get overwritten, since the same layer (and therefore the same hook name) is used twice in the SSL branch. Could you tell me whether my method is correct, and if not, suggest a better one?
Thanks! Here is my forward method:
def forward(self, x, type='SL', x2=None, return_features=False):
    """Run the network in supervised ('SL') or self-supervised ('SSL') mode.

    Args:
        x: Input image batch. In 'SL' mode it feeds both the backbone
            (conv1 -> layer4) and the f_conv1 branch; in 'SSL' mode it is the
            first view.
        type: 'SL' or 'SSL'. The name shadows the builtin ``type`` but is kept
            so existing ``model(x, type=...)`` calls keep working.
        x2: Optional second view for 'SSL' mode. When given, both views go
            through the same backbone modules in a single call.
        return_features: If True, also return a dict of intermediate feature
            maps. This avoids forward hooks, whose name-keyed storage is
            overwritten when a module (e.g. ``layer1``) runs more than once
            per forward pass.

    Returns:
        'SL': ``self.linear`` applied to the flattened ``self.avgpool`` output.
        'SSL': ``self.avgpool`` output for ``x``; or the tuple
            ``(feat, feat1)`` for ``(x, x2)`` when ``x2`` is given.
        With ``return_features=True`` the value above is returned as
        ``(output, features)``, where ``features`` contains:
            'SL':  'layer1' (layer1 output) and
                   'f_conv1' (``f_conv1(x) * layer1 output``);
            'SSL': 'layer1' for ``x`` and, if ``x2`` is given, 'layer1_x2'.
        Feature tensors are returned as-is (not detached).

    Raises:
        ValueError: If ``type`` is neither 'SL' nor 'SSL'.
    """
    def encode(inp):
        # Stem (conv1 -> bn1 -> relu -> maxpool) followed by layer1..layer4;
        # every stage output is returned so callers can pick what they need.
        h = relu(self.bn1(self.conv1(inp)))
        h = self.maxpool(h)
        h1 = self.layer1(h)
        h2 = self.layer2(h1)
        h3 = self.layer3(h2)
        h4 = self.layer4(h3)
        return h1, h2, h3, h4

    if type == 'SL':
        h1, h2, h3, h4 = encode(x)
        # Each f_conv stage output is multiplied (`*`) by the backbone output
        # of the same stage before feeding the next f_conv stage.
        m1 = self.f_conv1(x) * h1
        m2 = self.f_conv2(m1) * h2
        m3 = self.f_conv3(m2) * h3
        m4 = self.f_conv4(m3) * h4
        out = self.avgpool(m4)
        out = out.view(out.size(0), -1)
        y = self.linear(out)
        if return_features:
            return y, {'layer1': h1, 'f_conv1': m1}
        return y

    if type == 'SSL':
        # Views come from the arguments only, never from module-level globals.
        h1, _, _, h4 = encode(x)
        feat = self.avgpool(h4)
        features = {'layer1': h1}
        if x2 is None:
            output = feat
        else:
            h11, _, _, h41 = encode(x2)
            feat1 = self.avgpool(h41)
            # Separate key: the second pass through layer1 must not overwrite
            # the first view's features.
            features['layer1_x2'] = h11
            output = (feat, feat1)
        return (output, features) if return_features else output

    raise ValueError(f"type must be 'SL' or 'SSL', got {type!r}")
I registered the hooks and extracted the features as follows:
# Hook name -> detached output of the hooked module, or a list of such outputs
# for hooks created with accumulate=True.
activation = {}


def get_activation(name, accumulate=False):
    """Return a forward hook that stores a module's output in ``activation``.

    Args:
        name: Key under which the output is stored in ``activation``.
        accumulate: If False (default), every call overwrites
            ``activation[name]``. A module that runs several times in one
            forward pass (e.g. ``layer1`` used for two views) therefore keeps
            only its last output. If True, outputs are appended to the list
            ``activation[name]`` in call order. Use a name that no
            non-accumulating hook writes to, and clear ``activation`` between
            forward passes so the list does not grow without bound.

    Returns:
        A callable with the ``(module, inputs, output)`` signature expected by
        ``register_forward_hook``.
    """
    def hook(module, inputs, output):
        # detach() cuts the stored tensor from the autograd graph.
        detached = output.detach()
        if accumulate:
            activation.setdefault(name, []).append(detached)
        else:
            activation[name] = detached

    return hook
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = Net(device, 10)
model = model.to(device)
# Probing only: eval() stops BatchNorm layers from updating their running
# statistics from these probe batches.
model.eval()

# Register each hook exactly once. Re-registering (as previously done for
# layer1) makes the hook run once per registration on every forward pass.
# Keep the handles so the hooks can be removed afterwards.
handles = [
    model.layer1.register_forward_hook(get_activation('layer1')),
    model.f_conv1.register_forward_hook(get_activation('f_conv1')),
]

# Inputs must be on the same device as the model.
x = torch.randn(1, 3, 128, 128, device=device)
# NOTE(review): x1/x2 are not defined in the visible code; these are stand-ins
# for the two augmented views (transforms2 / transforms3). Replace them with
# the real augmented images.
x1 = torch.randn(1, 3, 128, 128, device=device)
x2 = torch.randn(1, 3, 128, 128, device=device)

# `activation` is keyed only by hook name, so each later call of the same
# module overwrites the previous entry. Copy every result out under its own
# key right after the forward pass that produced it.
features = {}
with torch.no_grad():
    output = model(x, type='SL')
    features['sl_layer1'] = activation['layer1']
    features['sl_f_conv1'] = activation['f_conv1']

    output = model(x1, type='SSL')
    features['ssl_view1_layer1'] = activation['layer1']

    output = model(x2, type='SSL')
    features['ssl_view2_layer1'] = activation['layer1']

for key, value in features.items():
    print(key, value)

for handle in handles:
    handle.remove()