Thanks for the remarks, I corrected the shape issues.
I’m struggling to find some way where I can keep the training towards the entire network just for the parameters that don’t have the necessary confidence factor, while for the others I would skip the network using the prediction1 response as output.
Conf: tensor([[0.5005],
[0.4718],
[0.4872],
[0.5072],
[0.4926],
[0.4972],
[0.4870],
[0.4741],
[0.4993],
[0.4778],
[0.4534],
[0.4733],
[0.4882],
[0.4665],
[0.5044],
[0.4772],
[0.4403],
[0.4896],
[0.4819],
[0.5052],
[0.5027],
[0.5055],
[0.4560],
[0.4652],
[0.4468],
[0.4912],
[0.4780],
[0.5073],
[0.4424],
[0.4987],
[0.4707],
[0.4368],
[0.5535],
[0.4541],
[0.4705],
[0.4757],
[0.4915],
[0.4699],
[0.5004],
[0.4482],
[0.4416],
[0.4536],
[0.4647],
[0.4955],
[0.5073],
[0.4747],
[0.4856],
[0.4811],
[0.5065],
[0.5092],
[0.5050],
[0.5114],
[0.4810],
[0.4393],
[0.4722],
[0.4862],
[0.4553],
[0.4814],
[0.4841],
[0.5039],
[0.4837],
[0.4835],
[0.4797],
[0.5005],
[0.4784],
[0.4981],
[0.4992],
[0.5059],
[0.4935],
[0.4716],
[0.4741],
[0.4379],
[0.5091],
[0.4794],
[0.4929],
[0.4620],
[0.4526],
[0.4941],
[0.4537],
[0.5121],
[0.5032],
[0.4723],
[0.4958],
[0.4663],
[0.5164],
[0.4697],
[0.5045],
[0.4950],
[0.4547],
[0.4561],
[0.4526],
[0.4773],
[0.4735],
[0.4515],
[0.4724],
[0.4979],
[0.5048],
[0.5185],
[0.4990],
[0.5244]])
Pred: tensor([[0.1105, 0.0857, 0.0984, 0.0920, 0.1039, 0.1064, 0.1122, 0.0969, 0.0880,
0.1060],
[0.1010, 0.0860, 0.1045, 0.0978, 0.0930, 0.0990, 0.1080, 0.0979, 0.1034,
0.1094],
[0.1003, 0.0888, 0.1083, 0.1014, 0.0909, 0.1064, 0.0841, 0.1023, 0.1105,
0.1069],
[0.1017, 0.0917, 0.0943, 0.0922, 0.0894, 0.1195, 0.1186, 0.1073, 0.0898,
0.0954],
[0.1110, 0.0872, 0.0932, 0.1178, 0.0988, 0.0968, 0.1063, 0.0936, 0.1028,
0.0926],
[0.1026, 0.0969, 0.0998, 0.1034, 0.0970, 0.1094, 0.1043, 0.0981, 0.0756,
0.1131],
[0.1131, 0.0912, 0.1078, 0.1048, 0.0861, 0.0944, 0.0965, 0.0923, 0.0978,
0.1160],
[0.1010, 0.1110, 0.1055, 0.0885, 0.0906, 0.1076, 0.1320, 0.0759, 0.0874,
0.1006],
[0.1009, 0.0820, 0.1090, 0.1133, 0.0978, 0.1084, 0.0903, 0.0922, 0.1009,
0.1051],
[0.0997, 0.0958, 0.1026, 0.0993, 0.0949, 0.0911, 0.1047, 0.0877, 0.1027,
0.1214],
[0.1118, 0.0833, 0.1052, 0.1128, 0.0986, 0.0910, 0.0834, 0.0931, 0.1003,
0.1206],
[0.0995, 0.1037, 0.0992, 0.0983, 0.0958, 0.1068, 0.0998, 0.1073, 0.0965,
0.0930],
[0.1078, 0.1009, 0.0899, 0.1038, 0.0836, 0.1081, 0.1113, 0.1010, 0.0978,
0.0958],
[0.1021, 0.0947, 0.0997, 0.0980, 0.0801, 0.1065, 0.1127, 0.0983, 0.0968,
0.1111],
[0.0951, 0.0898, 0.0996, 0.1007, 0.0856, 0.1066, 0.1169, 0.0990, 0.0980,
0.1086],
[0.0945, 0.0903, 0.0918, 0.1032, 0.0826, 0.1137, 0.1123, 0.1031, 0.1129,
0.0957],
[0.1124, 0.1025, 0.0984, 0.1011, 0.1028, 0.0867, 0.1061, 0.0868, 0.1023,
0.1010],
[0.1034, 0.0878, 0.1072, 0.1030, 0.0942, 0.0985, 0.1052, 0.0860, 0.1084,
0.1064],
[0.0970, 0.0942, 0.1059, 0.1014, 0.0889, 0.1032, 0.1002, 0.0892, 0.1003,
0.1197],
[0.0999, 0.0847, 0.1098, 0.0934, 0.0941, 0.0948, 0.1118, 0.0933, 0.1089,
0.1092],
[0.1045, 0.0954, 0.1039, 0.1075, 0.0827, 0.0993, 0.1355, 0.0847, 0.0879,
0.0984],
[0.0973, 0.0883, 0.0938, 0.1057, 0.0935, 0.1019, 0.1207, 0.1035, 0.0915,
0.1037],
[0.1041, 0.1040, 0.1034, 0.0944, 0.0902, 0.0894, 0.1130, 0.0960, 0.0996,
0.1058],
[0.1072, 0.0968, 0.0995, 0.0966, 0.0851, 0.1004, 0.1110, 0.0994, 0.0931,
0.1109],
[0.1106, 0.0922, 0.0964, 0.1066, 0.0863, 0.1076, 0.0875, 0.0911, 0.1086,
0.1131],
[0.1181, 0.0947, 0.0978, 0.1129, 0.0833, 0.1045, 0.1206, 0.0806, 0.0920,
0.0956],
[0.0968, 0.0968, 0.1064, 0.0972, 0.0881, 0.1037, 0.1017, 0.0924, 0.1027,
0.1143],
[0.1142, 0.0927, 0.1159, 0.1099, 0.0858, 0.0981, 0.1002, 0.0838, 0.0887,
0.1108],
[0.1001, 0.0918, 0.1013, 0.0952, 0.1064, 0.0927, 0.1025, 0.0937, 0.1064,
0.1100],
[0.1125, 0.1161, 0.1037, 0.0886, 0.0854, 0.1014, 0.1186, 0.0886, 0.0847,
0.1003],
[0.1097, 0.0802, 0.0983, 0.1091, 0.1030, 0.0981, 0.1163, 0.0880, 0.0989,
0.0983],
[0.0979, 0.0986, 0.0937, 0.0933, 0.0756, 0.1185, 0.1169, 0.0935, 0.0940,
0.1179],
[0.0943, 0.0975, 0.1016, 0.0894, 0.0883, 0.1277, 0.1087, 0.0871, 0.0982,
0.1071],
[0.1238, 0.0829, 0.1029, 0.1138, 0.1038, 0.0917, 0.0936, 0.0818, 0.0857,
0.1201],
[0.1038, 0.0965, 0.1029, 0.0998, 0.1006, 0.0907, 0.1084, 0.1022, 0.0929,
0.1022],
[0.1152, 0.0910, 0.0980, 0.0958, 0.0898, 0.1071, 0.1098, 0.0971, 0.0871,
0.1091],
[0.1162, 0.0801, 0.1153, 0.1051, 0.0904, 0.0996, 0.0963, 0.0912, 0.0989,
0.1069],
[0.1036, 0.1158, 0.0975, 0.0898, 0.0858, 0.1273, 0.1003, 0.0822, 0.0928,
0.1050],
[0.1059, 0.0983, 0.1135, 0.0907, 0.0937, 0.1101, 0.1157, 0.0832, 0.0949,
0.0939],
[0.0973, 0.1059, 0.1153, 0.0944, 0.0860, 0.0969, 0.1254, 0.0869, 0.0888,
0.1033],
[0.1074, 0.0986, 0.1084, 0.1097, 0.0907, 0.0910, 0.0855, 0.0854, 0.1069,
0.1163],
[0.1037, 0.1109, 0.0947, 0.0952, 0.0927, 0.0898, 0.1193, 0.0948, 0.0938,
0.1052],
[0.1117, 0.0894, 0.0937, 0.0978, 0.0902, 0.0926, 0.1040, 0.0980, 0.0978,
0.1247],
[0.1073, 0.0884, 0.1075, 0.1094, 0.0809, 0.1022, 0.0997, 0.0957, 0.1103,
0.0986],
[0.1042, 0.0791, 0.1116, 0.1130, 0.0948, 0.1004, 0.0958, 0.0974, 0.0993,
0.1045],
[0.1163, 0.0983, 0.0997, 0.1124, 0.0849, 0.0860, 0.1092, 0.0803, 0.1003,
0.1126],
[0.0964, 0.0962, 0.0948, 0.1040, 0.0924, 0.0991, 0.1209, 0.0965, 0.0930,
0.1067],
[0.1044, 0.1028, 0.0937, 0.0983, 0.0760, 0.0986, 0.1346, 0.0901, 0.0970,
0.1045],
[0.0919, 0.0923, 0.1057, 0.1045, 0.1054, 0.0938, 0.1092, 0.1050, 0.0809,
0.1111],
[0.1022, 0.0861, 0.1039, 0.0954, 0.0971, 0.0951, 0.1030, 0.1070, 0.1047,
0.1054],
[0.1085, 0.0818, 0.0929, 0.0916, 0.0884, 0.1085, 0.1057, 0.1085, 0.1021,
0.1120],
[0.0988, 0.0888, 0.1046, 0.1037, 0.0895, 0.0936, 0.1016, 0.0953, 0.1025,
0.1214],
[0.1022, 0.0823, 0.1131, 0.1001, 0.0916, 0.1176, 0.0946, 0.0835, 0.1092,
0.1059],
[0.1089, 0.1005, 0.1041, 0.0939, 0.0982, 0.1090, 0.1069, 0.0892, 0.0850,
0.1042],
[0.1125, 0.0901, 0.0997, 0.0962, 0.0897, 0.1011, 0.1089, 0.0861, 0.1118,
0.1040],
[0.0985, 0.1001, 0.0934, 0.0920, 0.0772, 0.1127, 0.1195, 0.0819, 0.1003,
0.1245],
[0.1063, 0.1076, 0.1250, 0.0885, 0.0974, 0.0932, 0.1216, 0.0926, 0.0839,
0.0839],
[0.1244, 0.0952, 0.1065, 0.1054, 0.0949, 0.0864, 0.0870, 0.0929, 0.1023,
0.1050],
[0.1271, 0.0877, 0.1035, 0.0897, 0.1024, 0.0944, 0.0997, 0.0881, 0.0932,
0.1143],
[0.0970, 0.0918, 0.1061, 0.0966, 0.0915, 0.1096, 0.1184, 0.0814, 0.1039,
0.1037],
[0.0930, 0.0944, 0.0826, 0.0986, 0.0994, 0.1061, 0.1077, 0.0916, 0.1034,
0.1231],
[0.1186, 0.1110, 0.0993, 0.1073, 0.0817, 0.0921, 0.0918, 0.0829, 0.0943,
0.1211],
[0.1083, 0.0940, 0.1084, 0.1085, 0.0889, 0.1048, 0.1003, 0.0850, 0.0885,
0.1133],
[0.1050, 0.0876, 0.0833, 0.1124, 0.0844, 0.1017, 0.1135, 0.0801, 0.1059,
0.1262],
[0.1130, 0.0964, 0.1002, 0.1031, 0.0872, 0.0878, 0.1229, 0.0848, 0.0938,
0.1109],
[0.0927, 0.0948, 0.1013, 0.0915, 0.1004, 0.0984, 0.1017, 0.1015, 0.1029,
0.1147],
[0.0904, 0.0916, 0.1028, 0.1114, 0.0916, 0.0995, 0.1198, 0.1037, 0.0947,
0.0946],
[0.1021, 0.0967, 0.0925, 0.0938, 0.0788, 0.1146, 0.1172, 0.0919, 0.1058,
0.1066],
[0.0938, 0.0910, 0.0885, 0.0967, 0.1171, 0.1057, 0.0895, 0.0985, 0.1023,
0.1169],
[0.1001, 0.0933, 0.0970, 0.1218, 0.0838, 0.1008, 0.0938, 0.0877, 0.1057,
0.1161],
[0.1034, 0.0947, 0.0920, 0.1015, 0.0968, 0.0967, 0.1030, 0.0979, 0.1067,
0.1073],
[0.1147, 0.0975, 0.1206, 0.1084, 0.0932, 0.0828, 0.1041, 0.0840, 0.0999,
0.0947],
[0.1004, 0.0898, 0.0873, 0.0830, 0.0882, 0.1259, 0.1170, 0.0911, 0.0928,
0.1245],
[0.1043, 0.0872, 0.1040, 0.0979, 0.0922, 0.1004, 0.1097, 0.0913, 0.0903,
0.1227],
[0.1013, 0.0921, 0.1068, 0.1006, 0.0898, 0.1010, 0.0985, 0.0947, 0.1020,
0.1133],
[0.1303, 0.1179, 0.1015, 0.0968, 0.0878, 0.0852, 0.0897, 0.0892, 0.0960,
0.1057],
[0.0983, 0.1008, 0.1189, 0.0989, 0.0898, 0.1136, 0.1084, 0.0850, 0.0801,
0.1063],
[0.0933, 0.0907, 0.0940, 0.0961, 0.0859, 0.1150, 0.1193, 0.0916, 0.1059,
0.1082],
[0.0963, 0.0972, 0.1057, 0.0944, 0.0928, 0.1059, 0.1043, 0.0896, 0.1022,
0.1117],
[0.1023, 0.0958, 0.1029, 0.1066, 0.0862, 0.1032, 0.1208, 0.0908, 0.0903,
0.1010],
[0.0994, 0.0842, 0.1023, 0.1008, 0.1009, 0.1143, 0.1091, 0.0880, 0.0984,
0.1025],
[0.0927, 0.0846, 0.0897, 0.0904, 0.0930, 0.1119, 0.1197, 0.1015, 0.1028,
0.1136],
[0.1034, 0.0925, 0.0927, 0.0947, 0.0896, 0.1139, 0.0983, 0.0853, 0.1007,
0.1288],
[0.1224, 0.0861, 0.1013, 0.1045, 0.0853, 0.0989, 0.1038, 0.1036, 0.0789,
0.1151],
[0.0997, 0.0804, 0.1013, 0.0947, 0.0918, 0.0992, 0.0988, 0.1058, 0.1195,
0.1088],
[0.1001, 0.0731, 0.0846, 0.1081, 0.0866, 0.1174, 0.1019, 0.1071, 0.1034,
0.1178],
[0.1087, 0.0910, 0.1057, 0.0991, 0.0850, 0.1091, 0.1009, 0.0856, 0.0944,
0.1204],
[0.1155, 0.1000, 0.0968, 0.1053, 0.1170, 0.0892, 0.0955, 0.0862, 0.0893,
0.1051],
[0.1112, 0.0946, 0.0982, 0.1041, 0.0903, 0.1072, 0.1088, 0.0909, 0.0959,
0.0988],
[0.1076, 0.0939, 0.1066, 0.0976, 0.0975, 0.1001, 0.1038, 0.1028, 0.0844,
0.1056],
[0.1027, 0.0992, 0.1043, 0.1000, 0.0949, 0.0954, 0.1014, 0.0997, 0.0956,
0.1066],
[0.1102, 0.0847, 0.1120, 0.0979, 0.0912, 0.0950, 0.1073, 0.0991, 0.0975,
0.1051],
[0.1084, 0.0845, 0.1181, 0.1088, 0.0883, 0.0911, 0.0983, 0.1032, 0.1016,
0.0976],
[0.1105, 0.0881, 0.0957, 0.1029, 0.0870, 0.1057, 0.1048, 0.1015, 0.0968,
0.1071],
[0.1049, 0.0874, 0.0957, 0.0938, 0.0945, 0.1114, 0.1115, 0.0885, 0.0991,
0.1133],
[0.1048, 0.0855, 0.1196, 0.1000, 0.0970, 0.1034, 0.0996, 0.0895, 0.0933,
0.1073],
[0.1054, 0.0989, 0.0925, 0.0971, 0.0858, 0.1006, 0.1170, 0.1117, 0.0962,
0.0948],
[0.1025, 0.1067, 0.1050, 0.1031, 0.0991, 0.0926, 0.1345, 0.0791, 0.0869,
0.0905],
[0.1032, 0.0780, 0.0983, 0.0934, 0.0955, 0.1263, 0.1068, 0.0923, 0.0986,
0.1074],
[0.1031, 0.0886, 0.0963, 0.1050, 0.1003, 0.0989, 0.1026, 0.1094, 0.0843,
0.1114]])
So, in this output for example I would like to exit the training of every parameter that have a conf value higher than 0.5, while keeping the remaining ones in the “full network”. Is there some way that I can do this?