@ptrblck
Text description first (code follows):
I train the model from scratch. There is only one sub-graph that does the feature extraction, but I use it multiple times, once per image, in a single forward pass, so the feature extractor params are shared. You can think of it as a Siamese network gone wild.
Multiple images are passed into the forward function.
Each image goes through:
- An InceptionResnetV2 feature extractor, using only the first few layers, and no fully connected layers.
- A global max pooling layer, so now we have a 1d vector of features per image (for example, 1088 features)
After doing this for, say, 10 images, we now have 10 individual 1d vectors, each of size 1088.
We stack them into a single matrix, which has the shape (10, 1088),
and, for example, perform another global max pooling on them, reaching a 1d feature vector of size 1088, which represents the max features across all images.
From that point, a standard fully connected layer and an activation suitable for negative log likelihood (log-softmax) are used.
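To make the shape flow concrete, here is a minimal sketch of just this stack-and-pool-across-images step on dummy tensors (it assumes the extractor returns a (1, 1088, 1, 1) tensor per image; the 10 images and 1088 features are only the example numbers from above):

import torch
import torch.nn.functional as F

num_images, num_features = 10, 1088  # example numbers from the description above
# placeholder tensors standing in for the per-image extractor outputs, assumed shape (1, 1088, 1, 1)
per_image_features = [torch.randn(1, num_features, 1, 1) for _ in range(num_images)]
stacked = torch.stack(per_image_features)        # (10, 1, 1088, 1, 1)
stacked = stacked[:, 0, ...]                     # (10, 1088, 1, 1), drop the batch dim of 1
permuted = stacked.permute(2, 1, 0, 3)           # (1, 1088, 10, 1)
pooled = F.max_pool2d(permuted, kernel_size=permuted.shape[2:])  # (1, 1088, 1, 1)
vol_features = pooled[:, :, 0, 0]                # (1, 1088), the max features across all images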
Code:
I'll try to demonstrate it here in code. The code is not self-contained, but hopefully it demonstrates the scenario well, at least for static viewing.
Please tell me if it doesn't and I'll improve it.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

# inceptionresnetv2 and MyDense are my own modules (not included here)

class ReproduceIssue(nn.Module):
    def __init__(self):
        super().__init__()
        # this is an inception resnet v2 model, but only the first K layers (and no fully connected layers)
        self.feature_extractor = inceptionresnetv2.inceptionresnetv2(
            pretrained=False,
            num_classes=0,
            logical_units_num=14,  # keep only the first 14 parts (it's up to mixed_6a)
            input_channels_num=1,
            final_global_pooling='max'  # end it with global max pooling to get a 1d vector per sample
        )
        self.fc1 = MyDense(1088, 2, activation=None, batch_norm=False, dropout_rate=None)
        self._init_vars()  # ...

    def forward(self, *args):
        extracted_features = []
        for x in args:
            #curr_feat = self.feature_extractor(x)  # orig, before checkpointing
            curr_feat = checkpoint(self.feature_extractor, x)
            extracted_features.append(curr_feat)
        # now that we have all features, stack and perform an additional global max pooling
        stacked_slices_features = torch.stack(extracted_features)
        stacked_slices_features = stacked_slices_features[:, 0, ...]
        permuted_stacked = stacked_slices_features.permute(2, 1, 0, 3)
        vol_features = F.max_pool2d(permuted_stacked, kernel_size=permuted_stacked.shape[2:])
        # now, take our 1d vector and continue into a fully connected layer, and final activation
        logits = self.fc1(vol_features[:, :, 0, 0])
        preds = F.log_softmax(logits, dim=1)
        return preds
Note: For the sake of simplicity, assume that a minibatch will always be of size 1.
(My code does support minibatches of arbitrary size, but that overcomplicates the demonstration code, so I did not include it.)
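For completeness, a rough sketch of how such a model could be driven; the 299x299 input size, the dummy target, and the nll_loss call are placeholder assumptions for illustration, not my actual pipeline:

import torch
import torch.nn.functional as F

model = ReproduceIssue()
# 10 single-channel dummy images, each a minibatch of size 1 (placeholder size 299x299)
images = [torch.randn(1, 1, 299, 299) for _ in range(10)]
preds = model(*images)                       # (1, 2) log-probabilities
loss = F.nll_loss(preds, torch.tensor([0]))  # dummy target, just to drive backward()
loss.backward()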