Distributed launch utility unstable

The distributed launch utility seems to be unstable.
Running the program once with the following command

python -m torch.distributed.launch --nproc_per_node=3 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=62123 main.py

works fine:

1.0, 0.05, 2.1814, 0.1697, 2.0053, 0.2154
1.0, 0.05, 2.1804, 0.1674, 1.9767, 0.2406
1.0, 0.05, 2.1823, 0.1703, 1.9799, 0.2352
2.0, 0.05, 2.1526, 0.1779, 2.1166, 0.1908
2.0, 0.05, 2.1562, 0.1812, 2.0868, 0.2076
2.0, 0.05, 2.1593, 0.1741, 2.0935, 0.192
3.0, 0.05, 1.9386, 0.2413, 1.8037, 0.3017
3.0, 0.05, 1.9319, 0.2473, 1.8041, 0.2903
3.0, 0.05, 1.9286, 0.2443, 1.815, 0.2939
4.0, 0.05, 1.7522, 0.3153, 1.828, 0.3131
4.0, 0.05, 1.7504, 0.3207, 1.7613, 0.3245

After the program finishes, running the same command again (i.e., calling launch with the same arguments) results in an error:

File "/home/kirk/miniconda3/envs/torch/lib/python3.6/site-packages/torch/serialization.py", line 386, in load
    return _load(f, map_location, pickle_module, **pickle_load_args)
  File "/home/kirk/miniconda3/envs/torch/lib/python3.6/site-packages/torch/serialization.py", line 580, in _load
    deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
RuntimeError: storage has wrong size: expected 4333514340733757174 got 256

The error points to a broken checkpoint, not to the distributed launch.
It seems you’ve saved some checkpoints in previous runs without making sure that only a single process (e.g. rank 0) writes to the files.
This might lead to multiple processes writing to the same checkpoint file and thus corrupting it.
Could this be the case?
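A minimal sketch of the "single writer" idea, assuming each process already knows its global rank. `save_checkpoint` and the file name are hypothetical, and `pickle` stands in for `torch.save` so the snippet runs without PyTorch. In a real multi-process script you would typically also call `torch.distributed.barrier()` after the save, so other ranks don't read the file before rank 0 has finished writing it.

```python
import os
import pickle
import tempfile


def save_checkpoint(state, path, rank):
    """Write `state` to `path` only on global rank 0 (hypothetical helper).

    pickle stands in for torch.save here. Returns True if this process wrote.
    """
    if rank != 0:
        # Every other process skips the write, so the file has exactly one writer.
        return False
    with open(path, "wb") as f:
        pickle.dump(state, f)
    return True


if __name__ == "__main__":
    state = {"epoch": 4, "lr": 0.05}
    path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
    # Simulate three processes (ranks 0..2) all reaching the save call.
    writers = [rank for rank in range(3) if save_checkpoint(state, path, rank)]
    print(writers)  # [0]
    with open(path, "rb") as f:
        print(pickle.load(f))  # {'epoch': 4, 'lr': 0.05}
```

The point is that the condition is evaluated per process: all ranks run the same line, but only one of them actually touches the file.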

Thanks for the reply. That could be the case, but none of the examples I’ve seen that use distributed launch show how to properly save a checkpoint. I was just using torch.save to save it. Should I be using something else?

How do I do that? Any pointers or examples of where to look would be much appreciated!

@ptrblck So I did try your suggestion in the following way:

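import os
import torch

if torch.distributed.is_available() and torch.distributed.is_initialized():
    if os.environ['RANK'] == '0':  # environment variables are strings, not ints
        torch.save(checkpoint, 'checkpoint.pth')  # placeholder path
else:
    torch.save(checkpoint, 'checkpoint.pth')  # non-distributed run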

But I’m still getting the same error.

I’m not sure if the RANK env variable is useful at this point. In the ImageNet example the args.rank variable is used. Could you try that?

@ptrblck Thanks for the pointer. My scenario is single-node, multi-GPU. In that case, rank=0 and world_size=2. According to the ImageNet example, the checkpoint is saved if multiprocessing distributed training is not used, or if it is used and args.rank % ngpus_per_node == 0.
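For reference, the save condition in the PyTorch ImageNet example reads roughly `not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0)`, where `args.rank` is the global rank. A standalone sketch of that check; `should_save` is a hypothetical name and its arguments mirror the example's `args.multiprocessing_distributed`, `args.rank` and `ngpus_per_node`:

```python
def should_save(multiprocessing_distributed, rank, ngpus_per_node):
    """Sketch of the ImageNet example's checkpoint condition.

    `rank` is the *global* rank, so `rank % ngpus_per_node == 0` selects
    the first process on each node.
    """
    if not multiprocessing_distributed:
        # Single process (or DataParallel): it is the only writer.
        return True
    return rank % ngpus_per_node == 0


if __name__ == "__main__":
    # One node with 3 GPUs, as in the launch command above: global ranks 0, 1, 2.
    print([r for r in range(3) if should_save(True, r, 3)])  # [0]
    # Two nodes with 2 GPUs each: global ranks 0..3.
    print([r for r in range(4) if should_save(True, r, 2)])  # [0, 2]
```

On a single node this reduces to "only global rank 0 saves"; on several nodes it picks one writer per node, which matters if the nodes share a filesystem.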

Why isn’t it useful? My understanding is that it is set dynamically by the launch utility and contains the rank of whichever process is currently running.

The environment variable is the legacy approach and is only set if use_env was specified, as seen here.
From the docs:

5. Another way to pass ``local_rank`` to the subprocesses via environment variable
``LOCAL_RANK``. This behavior is enabled when you launch the script with
``--use_env=True``. You must adjust the subprocess example above to replace
``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher
will not pass ``--local_rank`` when you specify this flag.
.. warning::
    ``local_rank`` is NOT globally unique: it is only unique per process
    on a machine.  Thus, don't use it to decide if you should, e.g.,
    write to a networked filesystem.  See
    https://github.com/pytorch/pytorch/issues/12042 for an example of
    how things can go wrong if you don't do this correctly.
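The two ways of receiving `local_rank` described in that docs excerpt can be handled in one place. A sketch, assuming the script may be started either with or without `--use_env`; `parse_local_rank` is a hypothetical helper, and the fallback to 0 for a script started without the launcher is an assumption. As the warning says, this value is only unique per machine, so it is fine for picking a GPU but not for deciding which process writes a checkpoint.

```python
import argparse
import os


def parse_local_rank(argv=None):
    """Return the local rank handed over by torch.distributed.launch (sketch).

    Without --use_env the launcher passes --local_rank=N on the command line;
    with --use_env it sets the LOCAL_RANK environment variable instead.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=None)
    # parse_known_args ignores the script's other, unrelated arguments.
    args, _ = parser.parse_known_args(argv)
    if args.local_rank is not None:
        return args.local_rank
    # Neither present: assume a single, non-launched process.
    return int(os.environ.get("LOCAL_RANK", 0))


if __name__ == "__main__":
    print(parse_local_rank(["--local_rank=1"]))  # 1
```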

Thanks for the clarifications. Reading through the GitHub issues, it seems that:

  1. local_rank is actually the ID within a worker; multiple workers have a local_rank of 0, so they’re probably trampling each other’s checkpoints.

  2. the solution was to add a --global_rank command line argument as well.

  3. someone else comments that torch.distributed.launch sets up a RANK environment variable, which can be used to detect whether you are on the master process (with os.environ['RANK'] == '0' from Python).

  4. you can use torch.distributed.get_rank() to get the global rank. (I suppose this might be the most appropriate way to do it?)
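Points 3 and 4 can be combined into one helper. A sketch that prefers `torch.distributed.get_rank()` when a process group is initialized, otherwise falls back to the `RANK` environment variable (a string, hence the `int(...)`), and defaults to 0 for a plain single-process run. `get_global_rank` and `is_main_process` are hypothetical names; the PyTorch import is guarded so the snippet also runs where PyTorch is not installed.

```python
import os


def get_global_rank():
    """Best-effort global rank of the current process (sketch)."""
    try:
        import torch.distributed as dist
    except ImportError:
        dist = None
    if dist is not None and dist.is_available() and dist.is_initialized():
        # Globally unique across all nodes, unlike local_rank.
        return dist.get_rank()
    # Launcher-provided environment variable; its value is a string.
    return int(os.environ.get("RANK", 0))


def is_main_process():
    return get_global_rank() == 0


if __name__ == "__main__":
    # Simulate the value the launcher would set for the third process.
    os.environ["RANK"] = "2"
    print(get_global_rank(), is_main_process())  # 2 False
```

Checkpoint saving can then be guarded with `if is_main_process():`, which works the same whether or not the script was started through the launcher.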
