How to implement an early exit mechanism in a neural network

Hi, I’m trying to implement a neural network that can follow different paths according to its confidence in the predictions: for the “easy” data only the simpler, early layers are used, and for the data that don’t reach the necessary confidence score the entire network is used.

I would like to use the confidence score to decide whether the network should continue computing or make an early exit, but I’m having some problems with this comparison (in the forward pass) and I’m also not sure whether my logic is correct.

Here is my implementation:

    self.conv1 = nn.Sequential(         
        nn.Conv2d(
            in_channels=1,              
            out_channels=16,            
            kernel_size=5,              
            stride=1,                   
            padding=2,                  
        ),                              
        nn.ReLU(),                      
        nn.MaxPool2d(kernel_size=2),    
    )
    
    self.conv2 = nn.Sequential(         
        nn.Conv2d(16, 32, 5, 1, 2),     
        nn.ReLU(),                     
        nn.MaxPool2d(2),                
    )
    
    self.confidence1 = nn.Sequential(
        nn.Linear(16 * 14 * 14, 1),
        nn.Sigmoid(),
    )
    self.classifier1 = nn.Sequential(
        nn.Linear(16 * 14 * 14, 10),
        nn.Softmax(dim=1),
    )

    # fully connected layer, output 10 classes
    self.out1 = nn.Linear(16 * 14 * 14, 10)

    # Complete network
    self.out2 = nn.Linear(32 * 7 * 7  , 10) 
    

    def forward(self, x):
        # Go through the first block
        x = self.conv1(x)                  # [N, 16, 14, 14]

        # Confidence and early prediction
        x = x.view(x.size(0), -1)          # flatten to [N, 16 * 14 * 14]
        conf = self.confidence1(x)
        pred = self.classifier1(x)

        # Early exit if the confidence is high enough
        if (conf.item() > 0.5):
            return pred

        # Otherwise go through all the layers
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        output = self.out2(x)

        return output

Any help is appreciated.

The condition will synchronize your code, since the confidence value has to be moved to the CPU to evaluate the if statement, so you might want to rethink this data-dependent check if you are hitting performance issues. Also note that .item() only works on a single-element tensor, so the comparison as written assumes a batch size of 1.

Besides that, the returned values would have a different range, as self.classifier1 is using a softmax layer while self.out2 is a plain linear layer returning raw logits.
I don’t know how you are planning to train this model, but I would recommend checking that this is really what you want.

Also, in the “full” network pass, self.conv2 would receive a flattened activation, which would break due to a shape mismatch (nn.Conv2d expects a 4-dimensional input).
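
For reference, a rough sketch of a forward pass that avoids both issues could look like the snippet below. It keeps the 4-dimensional activation around for self.conv2, and it assumes a 1 x 28 x 28 input (as implied by your linear layer sizes) and a batch size of 1 at inference time for the early-exit check:

    def forward(self, x):
        x = self.conv1(x)                 # [N, 16, 14, 14]

        flat = x.view(x.size(0), -1)      # flattened copy, used only by the two heads
        conf = self.confidence1(flat)
        pred = self.classifier1(flat)     # probabilities; drop the Softmax if you want logits here

        # early exit; .item() assumes a single sample at inference time
        if conf.item() > 0.5:
            return pred

        x = self.conv2(x)                 # x is still [N, 16, 14, 14] at this point
        x = x.view(x.size(0), -1)         # [N, 32 * 7 * 7]
        return self.out2(x)               # raw logits, so consider returning logits from both exits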


Thanks for the remarks, I corrected the shape issues.

I’m struggling to find a way to keep training the entire network only on the samples that don’t reach the necessary confidence score, while for the other samples I would skip the rest of the network and use the early classifier1 prediction as the output.

Here is one of my outputs (the confidence and prediction values for one batch):

Conf: tensor([[0.5005],
[0.4718],
[0.4872],
[0.5072],
[0.4926],
[0.4972],
[0.4870],
[0.4741],
[0.4993],
[0.4778],
[0.4534],
[0.4733],
[0.4882],
[0.4665],
[0.5044],
[0.4772],
[0.4403],
[0.4896],
[0.4819],
[0.5052],
[0.5027],
[0.5055],
[0.4560],
[0.4652],
[0.4468],
[0.4912],
[0.4780],
[0.5073],
[0.4424],
[0.4987],
[0.4707],
[0.4368],
[0.5535],
[0.4541],
[0.4705],
[0.4757],
[0.4915],
[0.4699],
[0.5004],
[0.4482],
[0.4416],
[0.4536],
[0.4647],
[0.4955],
[0.5073],
[0.4747],
[0.4856],
[0.4811],
[0.5065],
[0.5092],
[0.5050],
[0.5114],
[0.4810],
[0.4393],
[0.4722],
[0.4862],
[0.4553],
[0.4814],
[0.4841],
[0.5039],
[0.4837],
[0.4835],
[0.4797],
[0.5005],
[0.4784],
[0.4981],
[0.4992],
[0.5059],
[0.4935],
[0.4716],
[0.4741],
[0.4379],
[0.5091],
[0.4794],
[0.4929],
[0.4620],
[0.4526],
[0.4941],
[0.4537],
[0.5121],
[0.5032],
[0.4723],
[0.4958],
[0.4663],
[0.5164],
[0.4697],
[0.5045],
[0.4950],
[0.4547],
[0.4561],
[0.4526],
[0.4773],
[0.4735],
[0.4515],
[0.4724],
[0.4979],
[0.5048],
[0.5185],
[0.4990],
[0.5244]])
Pred: tensor([[0.1105, 0.0857, 0.0984, 0.0920, 0.1039, 0.1064, 0.1122, 0.0969, 0.0880,
0.1060],
[0.1010, 0.0860, 0.1045, 0.0978, 0.0930, 0.0990, 0.1080, 0.0979, 0.1034,
0.1094],
[0.1003, 0.0888, 0.1083, 0.1014, 0.0909, 0.1064, 0.0841, 0.1023, 0.1105,
0.1069],
[0.1017, 0.0917, 0.0943, 0.0922, 0.0894, 0.1195, 0.1186, 0.1073, 0.0898,
0.0954],
[0.1110, 0.0872, 0.0932, 0.1178, 0.0988, 0.0968, 0.1063, 0.0936, 0.1028,
0.0926],
[0.1026, 0.0969, 0.0998, 0.1034, 0.0970, 0.1094, 0.1043, 0.0981, 0.0756,
0.1131],
[0.1131, 0.0912, 0.1078, 0.1048, 0.0861, 0.0944, 0.0965, 0.0923, 0.0978,
0.1160],
[0.1010, 0.1110, 0.1055, 0.0885, 0.0906, 0.1076, 0.1320, 0.0759, 0.0874,
0.1006],
[0.1009, 0.0820, 0.1090, 0.1133, 0.0978, 0.1084, 0.0903, 0.0922, 0.1009,
0.1051],
[0.0997, 0.0958, 0.1026, 0.0993, 0.0949, 0.0911, 0.1047, 0.0877, 0.1027,
0.1214],
[0.1118, 0.0833, 0.1052, 0.1128, 0.0986, 0.0910, 0.0834, 0.0931, 0.1003,
0.1206],
[0.0995, 0.1037, 0.0992, 0.0983, 0.0958, 0.1068, 0.0998, 0.1073, 0.0965,
0.0930],
[0.1078, 0.1009, 0.0899, 0.1038, 0.0836, 0.1081, 0.1113, 0.1010, 0.0978,
0.0958],
[0.1021, 0.0947, 0.0997, 0.0980, 0.0801, 0.1065, 0.1127, 0.0983, 0.0968,
0.1111],
[0.0951, 0.0898, 0.0996, 0.1007, 0.0856, 0.1066, 0.1169, 0.0990, 0.0980,
0.1086],
[0.0945, 0.0903, 0.0918, 0.1032, 0.0826, 0.1137, 0.1123, 0.1031, 0.1129,
0.0957],
[0.1124, 0.1025, 0.0984, 0.1011, 0.1028, 0.0867, 0.1061, 0.0868, 0.1023,
0.1010],
[0.1034, 0.0878, 0.1072, 0.1030, 0.0942, 0.0985, 0.1052, 0.0860, 0.1084,
0.1064],
[0.0970, 0.0942, 0.1059, 0.1014, 0.0889, 0.1032, 0.1002, 0.0892, 0.1003,
0.1197],
[0.0999, 0.0847, 0.1098, 0.0934, 0.0941, 0.0948, 0.1118, 0.0933, 0.1089,
0.1092],
[0.1045, 0.0954, 0.1039, 0.1075, 0.0827, 0.0993, 0.1355, 0.0847, 0.0879,
0.0984],
[0.0973, 0.0883, 0.0938, 0.1057, 0.0935, 0.1019, 0.1207, 0.1035, 0.0915,
0.1037],
[0.1041, 0.1040, 0.1034, 0.0944, 0.0902, 0.0894, 0.1130, 0.0960, 0.0996,
0.1058],
[0.1072, 0.0968, 0.0995, 0.0966, 0.0851, 0.1004, 0.1110, 0.0994, 0.0931,
0.1109],
[0.1106, 0.0922, 0.0964, 0.1066, 0.0863, 0.1076, 0.0875, 0.0911, 0.1086,
0.1131],
[0.1181, 0.0947, 0.0978, 0.1129, 0.0833, 0.1045, 0.1206, 0.0806, 0.0920,
0.0956],
[0.0968, 0.0968, 0.1064, 0.0972, 0.0881, 0.1037, 0.1017, 0.0924, 0.1027,
0.1143],
[0.1142, 0.0927, 0.1159, 0.1099, 0.0858, 0.0981, 0.1002, 0.0838, 0.0887,
0.1108],
[0.1001, 0.0918, 0.1013, 0.0952, 0.1064, 0.0927, 0.1025, 0.0937, 0.1064,
0.1100],
[0.1125, 0.1161, 0.1037, 0.0886, 0.0854, 0.1014, 0.1186, 0.0886, 0.0847,
0.1003],
[0.1097, 0.0802, 0.0983, 0.1091, 0.1030, 0.0981, 0.1163, 0.0880, 0.0989,
0.0983],
[0.0979, 0.0986, 0.0937, 0.0933, 0.0756, 0.1185, 0.1169, 0.0935, 0.0940,
0.1179],
[0.0943, 0.0975, 0.1016, 0.0894, 0.0883, 0.1277, 0.1087, 0.0871, 0.0982,
0.1071],
[0.1238, 0.0829, 0.1029, 0.1138, 0.1038, 0.0917, 0.0936, 0.0818, 0.0857,
0.1201],
[0.1038, 0.0965, 0.1029, 0.0998, 0.1006, 0.0907, 0.1084, 0.1022, 0.0929,
0.1022],
[0.1152, 0.0910, 0.0980, 0.0958, 0.0898, 0.1071, 0.1098, 0.0971, 0.0871,
0.1091],
[0.1162, 0.0801, 0.1153, 0.1051, 0.0904, 0.0996, 0.0963, 0.0912, 0.0989,
0.1069],
[0.1036, 0.1158, 0.0975, 0.0898, 0.0858, 0.1273, 0.1003, 0.0822, 0.0928,
0.1050],
[0.1059, 0.0983, 0.1135, 0.0907, 0.0937, 0.1101, 0.1157, 0.0832, 0.0949,
0.0939],
[0.0973, 0.1059, 0.1153, 0.0944, 0.0860, 0.0969, 0.1254, 0.0869, 0.0888,
0.1033],
[0.1074, 0.0986, 0.1084, 0.1097, 0.0907, 0.0910, 0.0855, 0.0854, 0.1069,
0.1163],
[0.1037, 0.1109, 0.0947, 0.0952, 0.0927, 0.0898, 0.1193, 0.0948, 0.0938,
0.1052],
[0.1117, 0.0894, 0.0937, 0.0978, 0.0902, 0.0926, 0.1040, 0.0980, 0.0978,
0.1247],
[0.1073, 0.0884, 0.1075, 0.1094, 0.0809, 0.1022, 0.0997, 0.0957, 0.1103,
0.0986],
[0.1042, 0.0791, 0.1116, 0.1130, 0.0948, 0.1004, 0.0958, 0.0974, 0.0993,
0.1045],
[0.1163, 0.0983, 0.0997, 0.1124, 0.0849, 0.0860, 0.1092, 0.0803, 0.1003,
0.1126],
[0.0964, 0.0962, 0.0948, 0.1040, 0.0924, 0.0991, 0.1209, 0.0965, 0.0930,
0.1067],
[0.1044, 0.1028, 0.0937, 0.0983, 0.0760, 0.0986, 0.1346, 0.0901, 0.0970,
0.1045],
[0.0919, 0.0923, 0.1057, 0.1045, 0.1054, 0.0938, 0.1092, 0.1050, 0.0809,
0.1111],
[0.1022, 0.0861, 0.1039, 0.0954, 0.0971, 0.0951, 0.1030, 0.1070, 0.1047,
0.1054],
[0.1085, 0.0818, 0.0929, 0.0916, 0.0884, 0.1085, 0.1057, 0.1085, 0.1021,
0.1120],
[0.0988, 0.0888, 0.1046, 0.1037, 0.0895, 0.0936, 0.1016, 0.0953, 0.1025,
0.1214],
[0.1022, 0.0823, 0.1131, 0.1001, 0.0916, 0.1176, 0.0946, 0.0835, 0.1092,
0.1059],
[0.1089, 0.1005, 0.1041, 0.0939, 0.0982, 0.1090, 0.1069, 0.0892, 0.0850,
0.1042],
[0.1125, 0.0901, 0.0997, 0.0962, 0.0897, 0.1011, 0.1089, 0.0861, 0.1118,
0.1040],
[0.0985, 0.1001, 0.0934, 0.0920, 0.0772, 0.1127, 0.1195, 0.0819, 0.1003,
0.1245],
[0.1063, 0.1076, 0.1250, 0.0885, 0.0974, 0.0932, 0.1216, 0.0926, 0.0839,
0.0839],
[0.1244, 0.0952, 0.1065, 0.1054, 0.0949, 0.0864, 0.0870, 0.0929, 0.1023,
0.1050],
[0.1271, 0.0877, 0.1035, 0.0897, 0.1024, 0.0944, 0.0997, 0.0881, 0.0932,
0.1143],
[0.0970, 0.0918, 0.1061, 0.0966, 0.0915, 0.1096, 0.1184, 0.0814, 0.1039,
0.1037],
[0.0930, 0.0944, 0.0826, 0.0986, 0.0994, 0.1061, 0.1077, 0.0916, 0.1034,
0.1231],
[0.1186, 0.1110, 0.0993, 0.1073, 0.0817, 0.0921, 0.0918, 0.0829, 0.0943,
0.1211],
[0.1083, 0.0940, 0.1084, 0.1085, 0.0889, 0.1048, 0.1003, 0.0850, 0.0885,
0.1133],
[0.1050, 0.0876, 0.0833, 0.1124, 0.0844, 0.1017, 0.1135, 0.0801, 0.1059,
0.1262],
[0.1130, 0.0964, 0.1002, 0.1031, 0.0872, 0.0878, 0.1229, 0.0848, 0.0938,
0.1109],
[0.0927, 0.0948, 0.1013, 0.0915, 0.1004, 0.0984, 0.1017, 0.1015, 0.1029,
0.1147],
[0.0904, 0.0916, 0.1028, 0.1114, 0.0916, 0.0995, 0.1198, 0.1037, 0.0947,
0.0946],
[0.1021, 0.0967, 0.0925, 0.0938, 0.0788, 0.1146, 0.1172, 0.0919, 0.1058,
0.1066],
[0.0938, 0.0910, 0.0885, 0.0967, 0.1171, 0.1057, 0.0895, 0.0985, 0.1023,
0.1169],
[0.1001, 0.0933, 0.0970, 0.1218, 0.0838, 0.1008, 0.0938, 0.0877, 0.1057,
0.1161],
[0.1034, 0.0947, 0.0920, 0.1015, 0.0968, 0.0967, 0.1030, 0.0979, 0.1067,
0.1073],
[0.1147, 0.0975, 0.1206, 0.1084, 0.0932, 0.0828, 0.1041, 0.0840, 0.0999,
0.0947],
[0.1004, 0.0898, 0.0873, 0.0830, 0.0882, 0.1259, 0.1170, 0.0911, 0.0928,
0.1245],
[0.1043, 0.0872, 0.1040, 0.0979, 0.0922, 0.1004, 0.1097, 0.0913, 0.0903,
0.1227],
[0.1013, 0.0921, 0.1068, 0.1006, 0.0898, 0.1010, 0.0985, 0.0947, 0.1020,
0.1133],
[0.1303, 0.1179, 0.1015, 0.0968, 0.0878, 0.0852, 0.0897, 0.0892, 0.0960,
0.1057],
[0.0983, 0.1008, 0.1189, 0.0989, 0.0898, 0.1136, 0.1084, 0.0850, 0.0801,
0.1063],
[0.0933, 0.0907, 0.0940, 0.0961, 0.0859, 0.1150, 0.1193, 0.0916, 0.1059,
0.1082],
[0.0963, 0.0972, 0.1057, 0.0944, 0.0928, 0.1059, 0.1043, 0.0896, 0.1022,
0.1117],
[0.1023, 0.0958, 0.1029, 0.1066, 0.0862, 0.1032, 0.1208, 0.0908, 0.0903,
0.1010],
[0.0994, 0.0842, 0.1023, 0.1008, 0.1009, 0.1143, 0.1091, 0.0880, 0.0984,
0.1025],
[0.0927, 0.0846, 0.0897, 0.0904, 0.0930, 0.1119, 0.1197, 0.1015, 0.1028,
0.1136],
[0.1034, 0.0925, 0.0927, 0.0947, 0.0896, 0.1139, 0.0983, 0.0853, 0.1007,
0.1288],
[0.1224, 0.0861, 0.1013, 0.1045, 0.0853, 0.0989, 0.1038, 0.1036, 0.0789,
0.1151],
[0.0997, 0.0804, 0.1013, 0.0947, 0.0918, 0.0992, 0.0988, 0.1058, 0.1195,
0.1088],
[0.1001, 0.0731, 0.0846, 0.1081, 0.0866, 0.1174, 0.1019, 0.1071, 0.1034,
0.1178],
[0.1087, 0.0910, 0.1057, 0.0991, 0.0850, 0.1091, 0.1009, 0.0856, 0.0944,
0.1204],
[0.1155, 0.1000, 0.0968, 0.1053, 0.1170, 0.0892, 0.0955, 0.0862, 0.0893,
0.1051],
[0.1112, 0.0946, 0.0982, 0.1041, 0.0903, 0.1072, 0.1088, 0.0909, 0.0959,
0.0988],
[0.1076, 0.0939, 0.1066, 0.0976, 0.0975, 0.1001, 0.1038, 0.1028, 0.0844,
0.1056],
[0.1027, 0.0992, 0.1043, 0.1000, 0.0949, 0.0954, 0.1014, 0.0997, 0.0956,
0.1066],
[0.1102, 0.0847, 0.1120, 0.0979, 0.0912, 0.0950, 0.1073, 0.0991, 0.0975,
0.1051],
[0.1084, 0.0845, 0.1181, 0.1088, 0.0883, 0.0911, 0.0983, 0.1032, 0.1016,
0.0976],
[0.1105, 0.0881, 0.0957, 0.1029, 0.0870, 0.1057, 0.1048, 0.1015, 0.0968,
0.1071],
[0.1049, 0.0874, 0.0957, 0.0938, 0.0945, 0.1114, 0.1115, 0.0885, 0.0991,
0.1133],
[0.1048, 0.0855, 0.1196, 0.1000, 0.0970, 0.1034, 0.0996, 0.0895, 0.0933,
0.1073],
[0.1054, 0.0989, 0.0925, 0.0971, 0.0858, 0.1006, 0.1170, 0.1117, 0.0962,
0.0948],
[0.1025, 0.1067, 0.1050, 0.1031, 0.0991, 0.0926, 0.1345, 0.0791, 0.0869,
0.0905],
[0.1032, 0.0780, 0.0983, 0.0934, 0.0955, 0.1263, 0.1068, 0.0923, 0.0986,
0.1074],
[0.1031, 0.0886, 0.0963, 0.1050, 0.1003, 0.0989, 0.1026, 0.1094, 0.0843,
0.1114]])

So, in this output for example, I would like to exit the training for every sample that has a conf value higher than 0.5, while keeping the remaining samples in the “full” network. Is there some way I can do this?

Perhaps you can dynamically iterate over your parameters during the training loop step, after you call .backward() on the loss but before you call optimizer.step(), and zero out the entries of the parameter tensors’ .grad attribute wherever you don’t want the parameters to be updated (i.e. where the confidence is already high enough).
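
To make the selection per-sample rather than per-parameter, another option would be to mask the per-sample loss of the deep exit, so that only the low-confidence samples contribute gradients to the deeper layers. A minimal sketch, assuming a training-mode forward that returns the early prediction, the confidence, and the deep prediction for the whole batch, raw logits from both heads, and placeholder names (model, optimizer, train_loader):

    criterion = nn.CrossEntropyLoss(reduction='none')         # per-sample losses

    for data, target in train_loader:
        optimizer.zero_grad()

        # hypothetical training-mode forward returning all three tensors
        pred_early, conf, pred_full = model(data)

        # the early exit is trained on every sample
        loss_early = criterion(pred_early, target).mean()

        # the deep exit is trained only on low-confidence samples
        mask = (conf.squeeze(1) <= 0.5).float()                # [N], 1.0 = keep training this sample
        loss_full = (criterion(pred_full, target) * mask).sum() / mask.sum().clamp(min=1)

        (loss_early + loss_full).backward()
        optimizer.step()

With something like this, self.conv2 and self.out2 only receive gradients from the samples below the threshold, while self.conv1 and self.classifier1 are updated on the whole batch (training self.confidence1 itself needs its own objective, which this sketch leaves out).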
