Selecting a particular layer from a PyTorch pre-trained model

I have a pre-trained model with 227 layers; the penultimate layer is an MLP. The pre-trained model is a state_dict, i.e. key-value pairs in which every value is a tensor object. Now I need to access the MLP layer and extract the embeddings for my test images. How can I do that?

'mlp_head.0.weight', tensor([0.9932, 0.9924, 1.0286, 0.9775, 1.0295, 1.0285, 1.0340, 0.9975, 1.0089,
1.0182, 1.0191, 1.0727, 1.0375, 1.0204, 1.0035, 1.0275, 1.0305, 1.0383,
0.9900, 1.0200, 1.0360, 1.0018, 1.0014, 0.9629, 0.9696, 1.0488, 0.9846,
1.0166, 1.0004, 1.0193, 1.0052, 1.0267, 1.0243, 1.0107, 0.9456, 1.0478,
0.9933, 1.0224, 1.0337, 0.8880, 1.0013, 0.9733, 1.0558, 0.9860, 1.0284,
1.0364, 1.0264, 1.0486, 1.0706, 1.0193, 0.9990, 1.0349, 1.0028, 0.9599,
1.0390, 1.0043, 1.0125, 1.0550, 1.0478, 1.0416, 0.9980, 0.9994, 0.9806,
0.9902, 1.0020, 1.0574, 0.9765, 1.0064, 0.9898, 0.9972, 1.0216, 1.0402,
0.9730, 1.0502, 1.0091, 0.9761, 1.0121, 1.0241, 1.0432, 0.9697, 1.0423,
1.0502, 1.0249, 1.0326, 1.0516, 1.0145, 1.0687, 0.9958, 1.0580, 0.9947,
0.8908, 0.9941, 1.0354, 1.0422, 1.0353, 1.0140, 1.0395, 1.0555, 1.0154,
1.0325, 1.0212, 0.9976, 1.0089, 0.9949, 1.0112, 0.8455, 1.0415, 1.0202,
1.0137, 0.9632, 1.0298, 1.0080, 1.0302, 0.9589, 1.0115, 1.0142, 1.0384,
1.0214, 1.0406, 1.0281, 1.0158, 1.0356, 1.0363, 0.9789, 1.0268, 0.9997,
0.9798, 1.0295, 1.0490, 1.0406, 1.0319, 1.0491, 0.9804, 1.0375, 1.0092,
1.0166, 0.9765, 1.0302, 1.0472, 1.0011, 0.9737, 1.0238, 0.9457, 1.0232,
1.0547, 1.0108, 1.0015, 1.0235, 1.0073, 0.9836, 1.0202, 1.0412, 0.9726,
1.0105, 1.0332, 1.0335, 1.0399, 0.9981, 0.9722, 0.9925, 1.0199, 1.0253,
0.9931, 1.0254, 1.0412, 1.0363, 0.9674, 1.0248, 0.9878, 1.0503, 1.0360,
0.9900, 1.0171, 1.0059, 1.0310, 1.0047, 1.0354, 1.0173, 1.0377, 1.0234,
1.0059, 1.0264, 1.0382, 0.9575, 1.0607, 1.0096, 1.0119, 0.9823, 0.9702,
1.0038, 1.0403, 0.9807, 1.0222, 0.9837, 0.9580, 1.0418, 1.0393, 1.0006,
1.0109, 0.9838, 1.0098, 1.0321, 1.0173, 1.0424, 0.9828, 1.0320, 1.0090,
0.7263, 1.0356, 1.0334, 1.0278, 1.0256, 1.0093, 1.0276, 1.0390, 1.0187,
1.0042, 1.0446, 1.0096, 1.0331, 1.0213, 1.0019, 1.0301, 1.0093, 1.0420,
1.0412, 0.9844, 0.9913, 1.0505, 1.0311, 0.9108, 1.0195, 1.0426, 1.0111,
1.0421, 1.0140, 1.0066, 1.0258, 0.9818, 1.0153, 0.9956, 1.0241, 1.0067,
1.0110, 1.0109, 1.0443, 1.0040, 1.0444, 1.0420, 1.0453, 1.0007, 1.0426,
1.0081, 0.9448, 1.0177, 1.0033, 1.0049, 0.9890, 1.0073, 1.0198, 1.0485,
1.0486, 1.0514, 1.0442, 1.0128, 1.0035, 0.9939, 1.0338, 1.0136, 1.0490,
1.0483, 1.0213, 0.9902, 1.0412, 0.9877, 0.9700, 1.0446, 1.0773, 1.0425,
0.9962, 0.9884, 0.9983, 1.0213, 0.9991, 0.9959, 1.0060, 1.0242, 1.0087,
0.9715, 1.0195, 1.0285, 1.0259, 1.0072, 0.9626, 1.0100, 1.0173, 0.9628,
1.0260, 0.9918, 1.0367, 0.9541, 0.9944, 1.0355, 1.0081, 0.9988, 0.9805,
1.0187, 1.0623, 0.9889, 1.0154, 1.0200, 1.0226, 1.0471, 1.0404, 1.0180,
1.0497, 1.0200, 1.0021, 1.0176, 1.0192, 0.9869, 1.0124, 1.0370, 1.0522,
0.9919, 1.0465, 0.9398, 0.9798, 1.0182, 1.0265, 1.0638, 1.0409, 1.0433,
1.0165, 1.0251, 1.0108, 1.0179, 1.0334, 1.0235, 0.9636, 0.9982, 0.9862,
1.0315, 1.0203, 1.0328, 1.0409, 1.0232, 1.0200, 1.0183, 1.0322, 1.0324,
0.9795, 0.9581, 1.0115, 0.9940, 1.0160, 1.0612, 0.0848, 0.9953, 0.9940,
0.9840, 1.0136, 1.0366, 1.0321, 1.0162, 1.0296, 1.0038, 0.9978, 1.0298,
0.9862, 1.0045, 1.0084, 1.0136, 0.9857, 0.9729, 1.0590, 1.0305, 1.0069,
1.0403, 0.9882, 1.0400, 1.0210, 1.0071, 1.0199, 1.0361, 1.0103, 1.0321,
0.9971, 1.0242, 1.0043, 1.0083, 1.0358, 1.0249, 1.0248, 1.0253, 1.0145,
1.0308, 1.0350, 1.0026, 1.0243, 1.0622, 1.0130, 1.0450, 1.0352, 1.0044,
1.0198, 0.9947, 1.0206, 1.0624, 1.0011, 1.0171, 1.0067, 0.9833, 1.0265,
1.0177, 1.0020, 1.0103, 0.9927, 1.0338, 1.0246, 1.0028, 1.0184, 1.0532,
1.0212, 1.0513, 0.9979, 1.0056, 0.9834, 1.0311, 1.0111, 0.9987, 0.9418,
1.0497, 0.9920, 1.0564, 1.0277, 1.0527, 1.0265, 1.0009, 1.0173, 1.0130,
1.0417, 1.0247, 0.9964, 1.0203, 0.9906, 1.0352, 0.9994, 1.0124, 1.0124,
1.0030, 1.0344, 0.9895, 1.0337, 0.9933, 1.0251, 1.0052, 1.0388, 1.0371,
1.0128, 0.9976, 0.9893, 1.0142, 1.0612, 1.0193, 1.0249, 0.9843, 1.0064,
1.0315, 1.0404, 1.0377, 1.0080, 1.0284, 1.0408, 1.0599, 1.0578, 0.9919,
1.0388, 1.0239, 1.0189, 1.0336, 0.9965, 0.9937, 1.0287, 1.0425, 0.9853,
1.0257, 1.0320, 1.0446, 1.0142, 1.0003, 1.0219, 1.0120, 1.0285, 0.9889,
1.0396, 1.0238, 0.9810, 1.0061, 0.9414, 1.0253, 1.0192, 1.0393, 1.0319,
1.0012, 0.9929, 1.0029, 1.0078, 1.0248, 1.0323, 1.0434, 0.9189],
device='cuda:0'))

Could you explain a bit more what “extract the embeddings on the test images” would mean?
If you would like to get the forward activations for some inputs (i.e., your test images), you could use forward hooks as described here.
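For example, a minimal sketch of the forward-hook pattern on a toy model (the layer sizes and names here are purely illustrative, not taken from your network):

```python
import torch
import torch.nn as nn

# Toy model: the second Linear layer plays the role of the "penultimate" layer.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 224 * 224, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
)

activations = {}

def save_activation(name):
    # Returns a hook that stores the layer's output under `name`.
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook

# Register the hook on a *module*, not on a tensor.
handle = model[1].register_forward_hook(save_activation("embedding"))

x = torch.randn(2, 3, 224, 224)   # stand-in for a batch of test images
_ = model(x)                      # the hook fires during the forward pass

print(activations["embedding"].shape)  # torch.Size([2, 512])
handle.remove()
```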

Hi @ptrblck
The pre-trained model I have is a transformer model, and it is stored as an ordered dictionary. Now I want to extract the embeddings from the MLP, which is the penultimate layer before the loss function. I have to extract the transformer embeddings from this MLP layer for my test images. I tried the hook functions, but since the pre-trained model is an ordered dictionary of key-value pairs, it raises the error “Tensor objects cannot have forward hooks”. How can I access the MLP layer to extract the embeddings? Here is the GitHub link for the pre-trained model: GitHub - zhongyy/Face-Transformer: Face Transformer for Recognition.

You would have to register the forward hook on a layer as given in my code snippet or you could alternatively return the desired tensor in a custom forward method. Based on the error message it seems you are trying to register the hook on a tensor, not a layer.
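As a rough illustration of that point (not the snippet linked earlier): the checkpoint has to be loaded into an instantiated model first, and the hook then goes on one of that model's modules. The class below is only a stand-in for the real Face-Transformer architecture, and the `mlp_head` attribute name is inferred from the state_dict key shown above, so verify both against your own code:

```python
import torch
import torch.nn as nn

# Stand-in for the real Face-Transformer model class (hypothetical layout);
# the important part is that hooks are registered on modules of this object,
# not on the tensors inside the loaded state_dict.
class DummyFaceTransformer(nn.Module):
    def __init__(self, dim=512, num_classes=10):
        super().__init__()
        self.backbone = nn.Linear(3 * 112 * 112, dim)  # placeholder for the transformer blocks
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, x):
        x = self.backbone(x.flatten(1))
        return self.mlp_head(x)

model = DummyFaceTransformer()
# With the real model you would load your checkpoint here, e.g.:
# model.load_state_dict(torch.load("checkpoint.pth", map_location="cpu"))
model.eval()

features = {}

def hook(module, inputs, output):
    # inputs[0] is the 512-dim activation entering the MLP head, i.e. the
    # embedding; use `output` instead if you want the activation leaving it.
    features["embedding"] = inputs[0].detach()

handle = model.mlp_head.register_forward_hook(hook)

with torch.no_grad():
    _ = model(torch.randn(4, 3, 112, 112))   # stand-in for preprocessed test images

print(features["embedding"].shape)           # torch.Size([4, 512])
handle.remove()
```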

Hi @ptrblck, thanks for the reply.
Without the pre-trained model, i.e. after removing the loss function and registering hook functions on the layers, I can extract embeddings of shape 512.
But with the pre-trained model, can't we extract the MLP embeddings?

You still haven’t defined what your understanding of “embeddings” is. If you want to get the weight parameter of a specific layer, you can directly index it via model.layer.weight. On the other hand, if you want to get a forward activation from a specific layer, you can register a forward hook or just return the activation in the forward method (in addition to the last output).
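To make the distinction concrete, here is a self-contained toy sketch of both options (the layer names are illustrative, not from the Face-Transformer repo):

```python
import torch
import torch.nn as nn

# Toy model with an illustrative "encoder -> head" layout.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 512)
        self.head = nn.Linear(512, 10)

    def forward(self, x):
        emb = self.encoder(x)      # penultimate activation
        out = self.head(emb)
        return out, emb            # option 2: return the activation alongside the output

model = Toy()

# Option 1: the weight *parameter* of a layer; no forward pass involved.
print(model.head.weight.shape)      # torch.Size([10, 512])

# Option 2: a forward *activation* for some input (here via the modified forward;
# a forward hook as in the earlier snippets would work as well).
out, emb = model(torch.randn(4, 8))
print(emb.shape)                    # torch.Size([4, 512])
```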