First, I followed the steps in this discussion so that the results are reproducible with different num_workers, using the exact steps described there. That discussion uses a worker_init_fn as follows:
def _init_fn(worker_id):
    # Not necessary, since we reinitialize the internal generator state of the worker at EVERY SAMPLE!
    pass
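For context, the per-sample reseeding that the comment refers to looks roughly like this; a minimal sketch in which MyDataset, its data field, and the set_epoch helper are illustrative names, not my actual code:

import random
import numpy as np
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):  # illustrative name, not the real dataset
    def __init__(self, data, base_seed):
        self.data = data
        self.base_seed = base_seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # Called from the training loop at the start of every epoch.
        self.epoch = epoch

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Reseed every RNG per sample, so any augmentation randomness depends
        # only on (base_seed, epoch, index) and not on which worker runs it.
        seed = self.base_seed + self.epoch * len(self.data) + index
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        return self.data[index]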
The seed function at the start:
import random
import numpy as np
import torch

def set_seed(seed):
    # Seed every RNG that training touches: NumPy, Python, and PyTorch (CPU + CUDA).
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Force deterministic cuDNN kernels and disable autotuning.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
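For completeness, this is roughly how the pieces fit together at startup; a minimal sketch in which the seed value, train_dataset, batch size, and worker count are all illustrative rather than my exact settings:

set_seed(42)  # illustrative seed, called once before building the model and loader

# torch.manual_seed above also fixes the global RNG from which the DataLoader's
# RandomSampler derives its shuffle order, so the batch order is seeded as well.
train_loader = torch.utils.data.DataLoader(
    train_dataset,            # illustrative dataset
    batch_size=64,            # illustrative
    shuffle=True,
    num_workers=4,            # illustrative
    worker_init_fn=_init_fn,
)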
I save a checkpoint at every epoch, including the random generator states:
checkpoint = {
    'cpu_rng_state': torch.get_rng_state(),
    'gpu_rng_state': torch.cuda.get_rng_state(),
    'gpu_rng_state_all': torch.cuda.get_rng_state_all(),
    'numpy_rng_state': np.random.get_state(),
    'py_rng_state': random.getstate()
}
And to resume training from a specific epoch, I restore the random generator states from the saved checkpoint of that epoch.
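Concretely, the save/restore around the epoch boundary looks roughly like this (the checkpoint path is illustrative):

# Save at the end of each epoch (the path is illustrative):
torch.save(checkpoint, f"checkpoint_epoch_{epoch}.pt")

# Restore before resuming from that epoch; each setter mirrors the getter above:
checkpoint = torch.load(f"checkpoint_epoch_{epoch}.pt")
torch.set_rng_state(checkpoint['cpu_rng_state'])
torch.cuda.set_rng_state(checkpoint['gpu_rng_state'])
torch.cuda.set_rng_state_all(checkpoint['gpu_rng_state_all'])
np.random.set_state(checkpoint['numpy_rng_state'])
random.setstate(checkpoint['py_rng_state'])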
Now, if I don't use persistent_workers (i.e., persistent_workers = False):
- If I run the training process again from scratch, I can successfully reproduce the same results with num_workers >= 0 for all epochs.
- And I can successfully resume training and reproduce the same results for all epochs.
For example, the batches for epoch 6 are as follows, and if I use the checkpoint of epoch 5, I can successfully regenerate the same batches for epoch 6:
batch: 0 tensor([502, 437, 150, 787, 320, 303, 154, 939, 68, 899, 657, 716, 85, 484,
334, 443, 734, 91, 473, 324, 733, 235, 75, 885, 739, 249, 704, 588,
563, 825, 516, 820, 110, 242, 266, 100, 384, 638, 212, 709, 785, 807,
295, 514, 460, 702, 207, 458, 415, 304, 653, 62, 104, 756, 348, 486,
221, 621, 754, 500, 456, 871, 333, 421])
batch: 1 tensor([913, 595, 386, 122, 64, 191, 428, 634, 679, 12, 206, 594, 214, 633,
849, 468, 18, 699, 24, 21, 248, 190, 884, 343, 940, 239, 124, 627,
43, 908, 536, 329, 654, 413, 332, 253, 390, 843, 316, 280, 385, 159,
234, 414, 137, 722, 710, 900, 222, 480, 344, 795, 664, 775, 335, 707,
746, 435, 326, 349, 815, 895, 289, 237])
batch: 2 tensor([944, 696, 678, 545, 455, 42, 891, 5, 463, 452, 59, 771, 128, 914,
751, 854, 101, 158, 592, 498, 524, 27, 394, 600, 765, 446, 391, 691,
625, 232, 923, 850, 184, 635, 95, 245, 540, 71, 29, 372, 649, 282,
360, 23, 290, 828, 50, 802, 194, 220, 658, 227, 142, 178, 140, 66,
596, 438, 10, 300, 697, 723, 816, 505])
batch: 3 tensor([389, 80, 862, 407, 598, 323, 430, 629, 669, 671, 844, 267, 681, 301,
201, 713, 168, 752, 448, 434, 856, 48, 467, 809, 155, 138, 37, 931,
781, 367, 44, 97, 271, 336, 96, 800, 476, 630, 404, 472, 770, 177,
205, 400, 766, 51, 322, 614, 742, 661, 520, 538, 149, 492, 162, 879,
575, 83, 924, 466, 925, 577, 706, 761])
batch: 4 tensor([509, 757, 318, 942, 327, 726, 53, 52, 417, 419, 792, 585, 909, 684,
873, 72, 799, 130, 935, 687, 283, 576, 380, 859, 228, 258, 832, 203,
427, 308, 405, 583, 915, 273, 406, 778, 624, 676, 321, 571, 515, 209,
608, 296, 172, 623, 418, 535, 487, 773, 555, 294, 772, 257, 260, 474,
744, 202, 560, 449, 721, 338, 462, 493])
batch: 5 tensor([587, 477, 930, 776, 928, 805, 113, 173, 783, 226, 156, 652, 672, 262,
808, 88, 632, 408, 933, 864, 65, 481, 745, 453, 846, 837, 636, 760,
397, 118, 817, 169, 798, 14, 233, 806, 167, 517, 708, 93, 701, 877,
523, 41, 838, 693, 642, 645, 131, 340, 357, 61, 6, 920, 801, 215,
852, 789, 298, 84, 762, 219, 244, 98])
batch: 6 tensor([567, 628, 910, 615, 160, 840, 606, 728, 532, 558, 1, 528, 170, 865,
610, 875, 442, 412, 788, 119, 164, 204, 165, 650, 860, 839, 192, 395,
8, 932, 579, 129, 715, 127, 355, 626, 512, 2, 265, 293, 365, 31,
857, 279, 639, 847, 163, 714, 803, 106, 63, 123, 759, 291, 74, 224,
116, 79, 562, 619, 197, 272, 315, 217])
batch: 7 tensor([496, 461, 252, 937, 182, 121, 906, 749, 926, 921, 522, 470, 58, 551,
897, 550, 552, 901, 319, 581, 306, 425, 107, 758, 682, 231, 299, 4,
667, 829, 465, 181, 543, 382, 368, 313, 274, 111, 114, 513, 675, 869,
302, 175, 180, 422, 39, 236, 748, 902, 819, 683, 907, 874, 491, 286,
292, 20, 786, 572, 582, 354, 673, 230])
batch: 8 tensor([861, 45, 7, 183, 620, 521, 518, 157, 882, 153, 651, 73, 105, 19,
823, 694, 152, 375, 377, 325, 229, 918, 537, 525, 366, 370, 352, 60,
87, 578, 459, 135, 396, 255, 905, 554, 238, 416, 584, 841, 929, 870,
9, 903, 858, 179, 640, 717, 200, 855, 889, 501, 548, 794, 631, 393,
264, 755, 102, 208, 677, 109, 269, 868])
batch: 9 tensor([893, 409, 740, 542, 188, 378, 818, 647, 309, 94, 330, 54, 161, 607,
278, 549, 126, 791, 539, 70, 362, 225, 784, 82, 605, 11, 530, 736,
351, 185, 186, 436, 254, 483, 941, 337, 410, 86, 790, 831, 49, 195,
526, 494, 488, 388, 381, 281, 147, 256, 78, 432, 108, 738, 133, 56,
464, 811, 132, 90, 342, 285, 547, 617])
batch: 10 tensor([193, 243, 613, 15, 55, 328, 622, 531, 705, 139, 569, 223, 917, 827,
507, 743, 503, 890, 544, 363, 141, 927, 361, 812, 475, 504, 665, 218,
392, 810, 145, 288, 77, 566, 490, 750, 506, 358, 305, 92, 144, 69,
120, 16, 35, 268, 830, 814, 546, 731, 867, 574, 591, 712, 916, 89,
747, 176, 166, 403, 519, 143, 735, 198])
batch: 11 tensor([ 67, 331, 314, 310, 660, 668, 440, 601, 287, 565, 373, 216, 590, 387,
720, 646, 780, 251, 196, 724, 777, 727, 489, 533, 339, 604, 892, 559,
779, 680, 350, 674, 698, 826, 457, 174, 247, 399, 383, 17, 637, 411,
374, 876, 880, 125, 719, 666, 26, 482, 284, 922, 353, 782, 346, 497,
347, 495, 115, 655, 199, 833, 804, 451])
batch: 12 tensor([836, 499, 508, 402, 848, 718, 602, 297, 312, 479, 842, 656, 376, 22,
189, 511, 768, 753, 767, 904, 689, 834, 616, 46, 527, 580, 561, 863,
529, 356, 589, 568, 146, 441, 700, 851, 662, 845, 433, 797, 423, 32,
597, 510, 641, 364, 250, 764, 793, 28, 134, 359, 478, 741, 564, 429,
609, 171, 835, 187, 612, 33, 711, 881])
batch: 13 tensor([936, 690, 737, 872, 866, 644, 36, 30, 136, 599, 469, 99, 894, 371,
725, 618, 369, 643, 3, 573, 534, 240, 444, 553, 25, 898, 210, 938,
695, 888, 593, 729, 445, 896, 431, 401, 34, 703, 883, 246, 485, 919,
103, 241, 934, 570, 276, 878, 603, 13, 611, 148, 341, 541, 821, 277,
685, 692, 263, 912, 0, 379, 81, 117])
batch: 14 tensor([824, 471, 47, 275, 556, 648, 450, 259, 686, 57, 822, 732, 670, 663,
813, 730, 317, 426, 420, 76, 796, 38, 887, 151, 911, 439, 774, 586,
270, 40, 447, 261, 943, 211, 424, 557, 307, 853, 659, 886, 763, 345,
213, 454, 398, 769, 311, 688, 112])
Now, if I use persistent_workers (i.e., persistent_workers = True), the behavior changes (a sketch of both loader configurations follows this list):
- The batches retrieved with num_workers = 0 (where persistent_workers = False) and with num_workers > 0 (where persistent_workers = True) agree only in the first epoch, both in the retrieved batches and in the result.
- If I run the training process again from scratch, I can successfully reproduce the same results for all epochs, both for num_workers > 0 and for num_workers = 0.
- And I can successfully resume training and reproduce the same results for each of them.
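For reference, the only difference between the two configurations is the loader construction, roughly as follows (ds, the batch size, and the worker count are illustrative):

# Workers are torn down and recreated every epoch, so their state is rebuilt
# from the freshly restored RNG state each time:
loader = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True,
                                     num_workers=4, worker_init_fn=_init_fn,
                                     persistent_workers=False)

# Workers stay alive across epochs and keep whatever internal state they
# accumulated, which the checkpoint above does not capture:
loader = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True,
                                     num_workers=4, worker_init_fn=_init_fn,
                                     persistent_workers=True)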
I wonder what is missing. Do we have to store the workers' states and then restore them again, and if so, how?