How to resume training when using DataLoader persistent_workers

First, I followed the steps in this discussion so that the results are reproducible with different num_workers.

I used the exact steps described there, which define worker_init_fn as follows:

def _init_fn(worker_id):
    # Not necessary since we will reinitialize the internal generator state of the worker at EVERY SAMPLE!
    pass
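
For context, the per-sample reseeding that the comment refers to (from the linked discussion) looks roughly like the following. This is only a minimal sketch: the MyDataset class, its fields, and the base_seed scheme are hypothetical illustrations, not the exact code from that discussion.

import random
import numpy as np
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, base_seed=0):
        self.data = data
        self.base_seed = base_seed

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Reinitialize the worker's generator state at EVERY sample,
        # using a seed derived only from the sample index (hypothetical scheme),
        # so random augmentations do not depend on which worker fetches the item.
        seed = self.base_seed + index
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # ... apply the (random) transforms to self.data[index] here ...
        return self.data[index]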

The seed function at the start:

import random
import numpy as np
import torch

def set_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

I save a checkpoint for every epoch, including the random generator states:

checkpoint = {
    'cpu_rng_state': torch.get_rng_state(),
    'gpu_rng_state': torch.cuda.get_rng_state(),
    'gpu_rng_state_all': torch.cuda.get_rng_state_all(),
    'numpy_rng_state': np.random.get_state(),
    'py_rng_state': random.getstate(),
}
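
For completeness, a minimal sketch of how such a checkpoint might be written at the end of each epoch; the model, optimizer, and epoch names and the file name are assumptions for illustration, not part of the original setup:

# Assumed names: model, optimizer, epoch are defined in the training loop.
checkpoint.update({
    'epoch': epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
})
torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pt')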

To resume training from a specific epoch, I restore the random generator states from the saved checkpoint of that epoch.
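
A minimal sketch of that restore step, assuming the checkpoint was written as above (the file name is again an assumption):

checkpoint = torch.load(f'checkpoint_epoch_{epoch}.pt')

# Restore every generator state that was saved.
torch.set_rng_state(checkpoint['cpu_rng_state'])
if torch.cuda.is_available():
    torch.cuda.set_rng_state(checkpoint['gpu_rng_state'])
    torch.cuda.set_rng_state_all(checkpoint['gpu_rng_state_all'])
np.random.set_state(checkpoint['numpy_rng_state'])
random.setstate(checkpoint['py_rng_state'])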

Now if I don’t use persistent_workers (i.e., persistent_workers = False):

  1. If I run the training process again from scratch, I can successfully reproduce the same results with any num_workers >= 0 for all epochs.
  2. I can also successfully resume training and reproduce the same results for all epochs.

For example, the batches for epoch 6 are as follows, and if I resume from the checkpoint of epoch 5, I can successfully regenerate the same batches for epoch 6:

batch: 0 tensor([502, 437, 150, 787, 320, 303, 154, 939,  68, 899, 657, 716,  85, 484,
        334, 443, 734,  91, 473, 324, 733, 235,  75, 885, 739, 249, 704, 588,
        563, 825, 516, 820, 110, 242, 266, 100, 384, 638, 212, 709, 785, 807,
        295, 514, 460, 702, 207, 458, 415, 304, 653,  62, 104, 756, 348, 486,
        221, 621, 754, 500, 456, 871, 333, 421])
batch: 1 tensor([913, 595, 386, 122,  64, 191, 428, 634, 679,  12, 206, 594, 214, 633,
        849, 468,  18, 699,  24,  21, 248, 190, 884, 343, 940, 239, 124, 627,
         43, 908, 536, 329, 654, 413, 332, 253, 390, 843, 316, 280, 385, 159,
        234, 414, 137, 722, 710, 900, 222, 480, 344, 795, 664, 775, 335, 707,
        746, 435, 326, 349, 815, 895, 289, 237])
batch: 2 tensor([944, 696, 678, 545, 455,  42, 891,   5, 463, 452,  59, 771, 128, 914,
        751, 854, 101, 158, 592, 498, 524,  27, 394, 600, 765, 446, 391, 691,
        625, 232, 923, 850, 184, 635,  95, 245, 540,  71,  29, 372, 649, 282,
        360,  23, 290, 828,  50, 802, 194, 220, 658, 227, 142, 178, 140,  66,
        596, 438,  10, 300, 697, 723, 816, 505])
batch: 3 tensor([389,  80, 862, 407, 598, 323, 430, 629, 669, 671, 844, 267, 681, 301,
        201, 713, 168, 752, 448, 434, 856,  48, 467, 809, 155, 138,  37, 931,
        781, 367,  44,  97, 271, 336,  96, 800, 476, 630, 404, 472, 770, 177,
        205, 400, 766,  51, 322, 614, 742, 661, 520, 538, 149, 492, 162, 879,
        575,  83, 924, 466, 925, 577, 706, 761])
batch: 4 tensor([509, 757, 318, 942, 327, 726,  53,  52, 417, 419, 792, 585, 909, 684,
        873,  72, 799, 130, 935, 687, 283, 576, 380, 859, 228, 258, 832, 203,
        427, 308, 405, 583, 915, 273, 406, 778, 624, 676, 321, 571, 515, 209,
        608, 296, 172, 623, 418, 535, 487, 773, 555, 294, 772, 257, 260, 474,
        744, 202, 560, 449, 721, 338, 462, 493])
batch: 5 tensor([587, 477, 930, 776, 928, 805, 113, 173, 783, 226, 156, 652, 672, 262,
        808,  88, 632, 408, 933, 864,  65, 481, 745, 453, 846, 837, 636, 760,
        397, 118, 817, 169, 798,  14, 233, 806, 167, 517, 708,  93, 701, 877,
        523,  41, 838, 693, 642, 645, 131, 340, 357,  61,   6, 920, 801, 215,
        852, 789, 298,  84, 762, 219, 244,  98])
batch: 6 tensor([567, 628, 910, 615, 160, 840, 606, 728, 532, 558,   1, 528, 170, 865,
        610, 875, 442, 412, 788, 119, 164, 204, 165, 650, 860, 839, 192, 395,
          8, 932, 579, 129, 715, 127, 355, 626, 512,   2, 265, 293, 365,  31,
        857, 279, 639, 847, 163, 714, 803, 106,  63, 123, 759, 291,  74, 224,
        116,  79, 562, 619, 197, 272, 315, 217])
batch: 7 tensor([496, 461, 252, 937, 182, 121, 906, 749, 926, 921, 522, 470,  58, 551,
        897, 550, 552, 901, 319, 581, 306, 425, 107, 758, 682, 231, 299,   4,
        667, 829, 465, 181, 543, 382, 368, 313, 274, 111, 114, 513, 675, 869,
        302, 175, 180, 422,  39, 236, 748, 902, 819, 683, 907, 874, 491, 286,
        292,  20, 786, 572, 582, 354, 673, 230])
batch: 8 tensor([861,  45,   7, 183, 620, 521, 518, 157, 882, 153, 651,  73, 105,  19,
        823, 694, 152, 375, 377, 325, 229, 918, 537, 525, 366, 370, 352,  60,
         87, 578, 459, 135, 396, 255, 905, 554, 238, 416, 584, 841, 929, 870,
          9, 903, 858, 179, 640, 717, 200, 855, 889, 501, 548, 794, 631, 393,
        264, 755, 102, 208, 677, 109, 269, 868])
batch: 9 tensor([893, 409, 740, 542, 188, 378, 818, 647, 309,  94, 330,  54, 161, 607,
        278, 549, 126, 791, 539,  70, 362, 225, 784,  82, 605,  11, 530, 736,
        351, 185, 186, 436, 254, 483, 941, 337, 410,  86, 790, 831,  49, 195,
        526, 494, 488, 388, 381, 281, 147, 256,  78, 432, 108, 738, 133,  56,
        464, 811, 132,  90, 342, 285, 547, 617])
batch: 10 tensor([193, 243, 613,  15,  55, 328, 622, 531, 705, 139, 569, 223, 917, 827,
        507, 743, 503, 890, 544, 363, 141, 927, 361, 812, 475, 504, 665, 218,
        392, 810, 145, 288,  77, 566, 490, 750, 506, 358, 305,  92, 144,  69,
        120,  16,  35, 268, 830, 814, 546, 731, 867, 574, 591, 712, 916,  89,
        747, 176, 166, 403, 519, 143, 735, 198])
batch: 11 tensor([ 67, 331, 314, 310, 660, 668, 440, 601, 287, 565, 373, 216, 590, 387,
        720, 646, 780, 251, 196, 724, 777, 727, 489, 533, 339, 604, 892, 559,
        779, 680, 350, 674, 698, 826, 457, 174, 247, 399, 383,  17, 637, 411,
        374, 876, 880, 125, 719, 666,  26, 482, 284, 922, 353, 782, 346, 497,
        347, 495, 115, 655, 199, 833, 804, 451])
batch: 12 tensor([836, 499, 508, 402, 848, 718, 602, 297, 312, 479, 842, 656, 376,  22,
        189, 511, 768, 753, 767, 904, 689, 834, 616,  46, 527, 580, 561, 863,
        529, 356, 589, 568, 146, 441, 700, 851, 662, 845, 433, 797, 423,  32,
        597, 510, 641, 364, 250, 764, 793,  28, 134, 359, 478, 741, 564, 429,
        609, 171, 835, 187, 612,  33, 711, 881])
batch: 13 tensor([936, 690, 737, 872, 866, 644,  36,  30, 136, 599, 469,  99, 894, 371,
        725, 618, 369, 643,   3, 573, 534, 240, 444, 553,  25, 898, 210, 938,
        695, 888, 593, 729, 445, 896, 431, 401,  34, 703, 883, 246, 485, 919,
        103, 241, 934, 570, 276, 878, 603,  13, 611, 148, 341, 541, 821, 277,
        685, 692, 263, 912,   0, 379,  81, 117])
batch: 14 tensor([824, 471,  47, 275, 556, 648, 450, 259, 686,  57, 822, 732, 670, 663,
        813, 730, 317, 426, 420,  76, 796,  38, 887, 151, 911, 439, 774, 586,
        270,  40, 447, 261, 943, 211, 424, 557, 307, 853, 659, 886, 763, 345,
        213, 454, 398, 769, 311, 688, 112])

Now, if I use persistent_workers (i.e., persistent_workers = True):

  1. The num_workers = 0 run (where persistent_workers = False) and the num_workers > 0 run (where persistent_workers = True) agree only in the first epoch's retrieved batches and results.
  2. If I run the training process again from scratch, I can successfully reproduce the same results for num_workers > 0 for all epochs, and also for num_workers = 0.
  3. And I can successfully resume training and reproduce the same results for each of them.

I wonder what is missing. Do we have to store the workers' state and then restore it again, and if so, how?

Solution: The problem described above arises when persistent_workers = True and shuffle = True, even if I seed the randomness in _init_fn(worker_id).

A solution that works, in order to use persistent_workers = True and obtain reproducible results for num_workers >= 0 whether starting training from scratch or resuming, is the following (see the sketch after the list):

  1. Use a sampler (e.g., torch.utils.data.RandomSampler) for your dataset and assign it to the DataLoader's sampler argument.
  2. Since the sampler option is mutually exclusive with shuffle, set shuffle = False in the DataLoader; the samples will then be shuffled by the torch.utils.data.RandomSampler.
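
A minimal sketch of this setup, assuming a hypothetical train_dataset and reusing set_seed and _init_fn from above (batch size and num_workers are placeholders):

from torch.utils.data import DataLoader, RandomSampler

set_seed(42)  # seed everything once at the start

# When no generator is passed, RandomSampler derives each epoch's shuffle order
# from the default CPU RNG, whose state is already saved/restored via the checkpoint.
sampler = RandomSampler(train_dataset)

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    sampler=sampler,          # mutually exclusive with shuffle=True
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
    worker_init_fn=_init_fn,
)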