PyTorch DDP -- RuntimeError: Rank 10 successfully reached monitoredBarrier, but received errors while waiting for send/recv from rank 0. Please check rank 0 logs for faulty rank

@fduwjj
Context: it was suggested that I upgrade my NCCL to 2.17.1/2.18.1; however, I am not able to use pytorch:23.02-py3 because it requires GPUs with CC >= 5 and my K80 GPUs have CC = 3.7.
So the newest NVIDIA container I could use was pytorch:22.12-py3, which ships NCCL 2.14.3. I am also not sure whether the problem below is caused by the NCCL version or by something else in PyTorch DDP.
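
For reference, here is a quick sanity check I run inside the container to confirm what PyTorch actually sees (the "NCCL version is" line in the output below comes from the same call in train.py):

import torch
print("NCCL version is: ", torch.cuda.nccl.version())               # -> (2, 14, 3) in pytorch:22.12-py3
print("torch version is: ", torch.__version__)
print("compute capability: ", torch.cuda.get_device_capability(0))  # K80 should report (3, 7)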

NCCL version is:  (2, 14, 3)
System information: Linux #36~20.04.1-Ubuntu SMP Tue Dec 6 17:00:26 UTC 2022
Python version: 3.8.10
MLflow version: 2.3.2
MLflow module location: /usr/local/lib/python3.8/dist-packages/mlflow/__init__.py
Tracking URI: URI
Registry URI: URI
MLflow environment variables: 
  MLFLOW_DISABLE_ENV_MANAGER_CONDA_WARNING: True
  MLFLOW_EXPERIMENT_ID: 03bf0c01-34b3-4b8f-9713-b744f0350832
  MLFLOW_EXPERIMENT_NAME: dev_CIFAR10_DDP_train_test2
  MLFLOW_RUN_ID:
 
MLflow dependencies: 
  Flask: 2.3.2
  Jinja2: 3.1.2
  alembic: 1.11.1
  click: 8.1.3
  cloudpickle: 2.2.0
  databricks-cli: 0.17.7
  docker: 6.1.2
  entrypoints: 0.4
  gitpython: 3.1.31
  gunicorn: 20.1.0
  importlib-metadata: 5.1.0
  markdown: 3.4.1
  matplotlib: 3.5.2
  numpy: 1.22.2
  packaging: 22.0
  pandas: 1.5.2
  protobuf: 3.20.1
  pyarrow: 9.0.0
  pytz: 2022.6
  pyyaml: 6.0
  querystring-parser: 1.2.4
  requests: 2.28.1
  scikit-learn: 0.24.2
  scipy: 1.6.3
  sqlalchemy: 2.0.15
  sqlparse: 0.4.4
34d0f284fac94434817d429e96547367000003:44:44 [2] NCCL INFO cudaDriverVersion 11040
34d0f284fac94434817d429e96547367000003:44:44 [2] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth0
34d0f284fac94434817d429e96547367000003:44:44 [2] NCCL INFO Bootstrap : Using eth0:10.0.0.7<0>
34d0f284fac94434817d429e96547367000003:44:44 [2] NCCL INFO NET/Plugin: Failed to find ncclNetPlugin_v6 symbol.
34d0f284fac94434817d429e96547367000003:44:44 [2] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin (v5)
34d0f284fac94434817d429e96547367000003:44:44 [2] NCCL INFO NET/Plugin: Failed to find ncclCollNetPlugin_v6 symbol.
34d0f284fac94434817d429e96547367000003:44:44 [2] NCCL INFO NET/Plugin: Loaded coll plugin SHARP (v5)
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO P2P plugin IBext
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth0
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO NET/IB : No device found.
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO NCCL_IB_DISABLE set by environment to 1.
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth0
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO NET/Socket : Using [0]eth0:10.0.0.7<0>
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Using network Socket
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0001-0000-3130-444531303244/pci0001:00/0001:00:00.0/../max_link_speed, ignoring
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0001-0000-3130-444531303244/pci0001:00/0001:00:00.0/../max_link_width, ignoring
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0002-0000-3130-444531303244/pci0002:00/0002:00:00.0/../max_link_speed, ignoring
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0002-0000-3130-444531303244/pci0002:00/0002:00:00.0/../max_link_width, ignoring
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0003-0000-3130-444531303244/pci0003:00/0003:00:00.0/../max_link_speed, ignoring
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0003-0000-3130-444531303244/pci0003:00/0003:00:00.0/../max_link_width, ignoring
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0004-0000-3130-444531303244/pci0004:00/0004:00:00.0/../max_link_speed, ignoring
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0004-0000-3130-444531303244/pci0004:00/0004:00:00.0/../max_link_width, ignoring
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Topology detection: network path /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/000d3a00-71d4-000d-3a00-71d4000d3a00 is not a PCI device (vmbus). Attaching to first CPU
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO KV Convert to int : could not find value of '' in dictionary, falling back to 60
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO KV Convert to int : could not find value of '' in dictionary, falling back to 60
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO KV Convert to int : could not find value of '' in dictionary, falling back to 60
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO KV Convert to int : could not find value of '' in dictionary, falling back to 60
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO === System : maxBw 5.0 totalBw 12.0 ===
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO CPU/0 (1/1/1)
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO + PCI[5000.0] - NIC/0
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO                 + NET[5.0] - NET/0 (0/0/5.000000)
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO + PCI[12.0] - GPU/100000 (8)
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO + PCI[12.0] - GPU/200000 (9)
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO + PCI[12.0] - GPU/300000 (10)
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO + PCI[12.0] - GPU/400000 (11)
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO ==========================================
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO GPU/100000 :GPU/100000 (0/5000.000000/LOC) GPU/200000 (2/12.000000/PHB) GPU/300000 (2/12.000000/PHB) GPU/400000 (2/12.000000/PHB) CPU/0 (1/12.000000/PHB) NET/0 (3/5.000000/PHB) 
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO GPU/200000 :GPU/100000 (2/12.000000/PHB) GPU/200000 (0/5000.000000/LOC) GPU/300000 (2/12.000000/PHB) GPU/400000 (2/12.000000/PHB) CPU/0 (1/12.000000/PHB) NET/0 (3/5.000000/PHB) 
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO GPU/300000 :GPU/100000 (2/12.000000/PHB) GPU/200000 (2/12.000000/PHB) GPU/300000 (0/5000.000000/LOC) GPU/400000 (2/12.000000/PHB) CPU/0 (1/12.000000/PHB) NET/0 (3/5.000000/PHB) 
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO GPU/400000 :GPU/100000 (2/12.000000/PHB) GPU/200000 (2/12.000000/PHB) GPU/300000 (2/12.000000/PHB) GPU/400000 (0/5000.000000/LOC) CPU/0 (1/12.000000/PHB) NET/0 (3/5.000000/PHB) 
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO NET/0 :GPU/100000 (3/5.000000/PHB) GPU/200000 (3/5.000000/PHB) GPU/300000 (3/5.000000/PHB) GPU/400000 (3/5.000000/PHB) CPU/0 (2/5.000000/PHB) NET/0 (0/5000.000000/LOC) 
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Setting affinity for GPU 2 to 0fff
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Pattern 4, crossNic 0, nChannels 1, bw 5.000000/5.000000, type PHB/PHB, sameChannels 1
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO  0 : NET/0 GPU/8 GPU/9 GPU/10 GPU/11 NET/0
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Pattern 1, crossNic 0, nChannels 1, bw 6.000000/5.000000, type PHB/PHB, sameChannels 1
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO  0 : NET/0 GPU/8 GPU/9 GPU/10 GPU/11 NET/0
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Pattern 3, crossNic 0, nChannels 0, bw 0.000000/0.000000, type NVL/PIX, sameChannels 1
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Ring 00 : 9 -> 10 -> 11
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Ring 01 : 9 -> 10 -> 11
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Trees [0] 11/-1/-1->10->9 [1] 11/-1/-1->10->9
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Channel 00 : 10[300000] -> 11[400000] via SHM/direct/direct
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Channel 01 : 10[300000] -> 11[400000] via SHM/direct/direct
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Connected all rings
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Channel 00 : 10[300000] -> 9[200000] via SHM/direct/direct
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Channel 01 : 10[300000] -> 9[200000] via SHM/direct/direct
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO Connected all trees
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO threadThresholds 8/8/64 | 128/8/64 | 512 | 512
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO 2 coll channels, 2 p2p channels, 2 p2p channels per peer
World size: 16
local rank is 2 and world rank is 10
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data_10/cifar-10-python.tar.gz
Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data_10/cifar-10-python.tar.gz

  0%|          | 0/170498071 [00:00<?, ?it/s]
[...]
100%|██████████| 170498071/170498071 [00:02<00:00, 84146552.80it/s]
/usr/local/lib/python3.8/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
Extracting ./data_10/cifar-10-python.tar.gz to ./data_10
Files already downloaded and verified

  0%|          | 0.00/97.8M [00:00<?, ?B/s]
[...]
100%|██████████| 97.8M/97.8M [00:00<00:00, 265MB/s]
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO NCCL_P2P_PXN_LEVEL set by environment to 0.
34d0f284fac94434817d429e96547367000003:44:218 [2] NCCL INFO comm 0x2c1e8620 rank 10 nranks 16 cudaDev 2 busId 300000 - Init COMPLETE
[E ProcessGroupGloo.cpp:137] Rank 10 successfully reached monitoredBarrier, but received errors while waiting for send/recv from rank 0. Please check rank 0 logs for faulty rank.

Traceback (most recent call last):
  File "train.py", line 107, in <module>
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 655, in __init__
    _verify_param_shape_across_processes(self.process_group, parameters)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/utils.py", line 112, in _verify_param_shape_across_processes
    return dist._verify_params_across_processes(process_group, tensors, logger)
RuntimeError: Rank 10 successfully reached monitoredBarrier, but received errors while waiting for send/recv from rank 0. Please check rank 0 logs for faulty rank.
 Original exception: 
[../third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:133] Timed out waiting 1800000ms for send operation to complete
34d0f284fac94434817d429e96547367000003:44:220 [2] NCCL INFO [Service thread] Connection closed by localRank 2
34d0f284fac94434817d429e96547367000003:44:44 [2] NCCL INFO comm 0x2c1e8620 rank 10 nranks 16 cudaDev 2 busId 300000 - Abort COMPLETE
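
(For what it's worth, the 1800000 ms in the Gloo timeout above is exactly 30 minutes, which I believe is the default process-group timeout, so rank 10 waited the full default window for rank 0.)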

Here’s my Dockerfile:

# check release notes https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html
# nvidia containers https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-23-04.html#rel-23-04
# FROM nvcr.io/nvidia/pytorch:22.04-py3
# FROM nvcr.io/nvidia/pytorch:23.02-py3  # requires GPUs with compute capability of 5+
FROM nvcr.io/nvidia/pytorch:22.12-py3

##############################################################################
# NCCL TESTS
##############################################################################
ENV NCCL_TESTS_TAG=v2.11.0

# NOTE: adding gencodes to support K80, M60, V100, A100
RUN mkdir /tmp/nccltests && \
    cd /tmp/nccltests && \
    git clone -b ${NCCL_TESTS_TAG} https://github.com/NVIDIA/nccl-tests.git && \
    cd nccl-tests && \
    make \
    MPI=1 MPI_HOME=/opt/hpcx/ompi \
    NVCC_GENCODE="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_80,code=sm_80" \
    CUDA_HOME=/usr/local/cuda && \
    cp ./build/* /usr/local/bin && \
    rm -rf /tmp/nccltests

# Install dependencies missing in this container
# NOTE: container already has matplotlib==3.5.1 tqdm==4.62.0
COPY requirements.txt ./
RUN pip install -r requirements.txt


# add ndv4-topo.xml
RUN mkdir /opt/microsoft/
ADD ./ndv4-topo.xml /opt/microsoft

# to use on A100, enable env var below in your job
# ENV NCCL_TOPO_FILE="/opt/microsoft/ndv4-topo.xml"

# adjusts the level of info from NCCL tests
ENV NCCL_DEBUG="INFO"
ENV NCCL_DEBUG_SUBSYS="GRAPH,INIT,ENV"

# Relaxed Ordering can greatly help the performance of Infiniband networks in virtualized environments.
# ENV NCCL_IB_PCI_RELAXED_ORDERING="1"
# suggested to set ENV NCCL_IB_PCI_RELAXED_ORDERING to 0 for NCCL 2.18.1
ENV NCCL_IB_PCI_RELAXED_ORDERING="0" 
ENV CUDA_DEVICE_ORDER="PCI_BUS_ID"
ENV NCCL_SOCKET_IFNAME="eth0"
ENV NCCL_P2P_PXN_LEVEL="0"
# ENV NCCL_SOCKET_IFNAME='lo'
ENV NCCL_IB_DISABLE="1"

and here’s my requirements.txt:

# torch and torchvision compatibility matrix https://github.com/pytorch/pytorch/wiki/PyTorch-Versions
torch==1.13.0
torchvision==0.14.0

mlflow==2.3.2
azureml-mlflow==1.50.0
matplotlib==3.5.2
tqdm==4.64.0
psutil==5.9.0

# for unit testing
pytest==7.1.2

And here’s my train.py:

import time
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import mlflow
import os
import datetime
import gc
import configparser
import logging
import argparse

from PIL import Image
from torch.distributed.elastic.multiprocessing.errors import record  # TODO: create a main() and decorate it with @record later

import ssl
ssl._create_default_https_context = ssl._create_unverified_context


start_time = time.time()

# torch.cuda.empty_cache()
# gc.collect()


torch.backends.cudnn.benchmark=False
torch.backends.cudnn.deterministic=True


print("NCCL version is: ", torch.cuda.nccl.version())


# MLflow >= 2.0
mlflow.doctor()

# Set the seed for reproducibility
torch.manual_seed(42)

# Set up the data loading parameters
batch_size = 128
num_epochs = 10
num_workers = 4
pin_memory = True

# Get the world size and rank to determine the process group
world_size = int(os.environ['WORLD_SIZE'])
world_rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])

print("World size:", world_size)
print("local rank is {} and world rank is {}".format(local_rank, world_rank))

is_distributed = world_size > 1

if is_distributed:
    batch_size = batch_size // world_size
    batch_size = max(batch_size, 1)
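    # e.g. with world_size = 16: 128 // 16 = 8 images per rank per step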

# Set the backend to NCCL for distributed training
dist.init_process_group(backend="nccl",
                        init_method="env://",
                        world_size=world_size,
                        rank=world_rank)

# Set the device to the current local rank
torch.cuda.set_device(local_rank)
device = torch.device('cuda', local_rank)

dist.barrier()

# Define the transforms for the dataset
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # keep test preprocessing consistent with training
])

# Load the CIFAR-10 dataset

data_root = './data_' + str(world_rank)
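# NOTE: download=True with a per-rank data_root means every rank fetches its own ~170 MB copy of CIFAR-10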
train_dataset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset=train_dataset, num_replicas=world_size, rank=world_rank, shuffle=True) if is_distributed else None
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), num_workers=num_workers, pin_memory=pin_memory, sampler=train_sampler)

test_dataset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

# Define the ResNet50 model
model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)  # 'pretrained' is deprecated since torchvision 0.13
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)

# Move the model to the GPU
model = model.to(device)

# Wrap the model with DistributedDataParallel
if is_distributed:
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model for the specified number of epochs
for epoch in range(num_epochs):
    running_loss = 0.0
    if is_distributed:
        # set_epoch() reseeds the DistributedSampler so each epoch sees a different
        # shuffle (it also avoids an AttributeError when the sampler is None).
        train_sampler.set_epoch(epoch)
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()

    print('[Epoch %d] loss: %.3f' % (epoch + 1, running_loss))
    if world_rank == 0:
        # Log the loss and running loss as MLFlow metrics
        mlflow.log_metric("loss", loss.item())
        mlflow.log_metric("running loss", running_loss)

dist.barrier()
# Save the trained model
if world_rank == 0:
    checkpoints_path = "train_checkpoints"
    os.makedirs(checkpoints_path, exist_ok=True)
    torch.save(model.state_dict(), '{}/{}-{}.pth'.format(checkpoints_path, 'resnet50_cifar10', world_rank))
    mlflow.pytorch.log_model(model, "resnet50_cifar10_{}.pth".format(world_rank))
    # mlflow.log_artifact('{}/{}-{}.pth'.format(checkpoints_path, 'resnet50_cifar10', world_rank), artifact_path="model_state_dict")

    # Evaluate the model on the test set and save inference on 6 random images
    correct = 0
    total = 0
    with torch.no_grad():
        fig, axs = plt.subplots(2, 3, figsize=(8, 6), dpi=100)
        axs = axs.flatten()
        count = 0
        for data in test_loader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Save the inference on the 6 random images
            if count < 6:
                image = np.transpose(inputs[0].cpu().numpy(), (1, 2, 0))
                confidence = torch.softmax(outputs, dim=1)[0][predicted[0]].cpu().numpy()
                class_name = test_dataset.classes[predicted[0]]
                axs[count].imshow(image)
                axs[count].set_title(f'Class: {class_name}\nConfidence: {confidence:.2f}')
                axs[count].axis('off')
                count += 1
            
            
    test_accuracy = 100 * correct / total
    print('Test accuracy: %.2f %%' % test_accuracy)

# # Average the test accuracy across all processes

# correct = torch.tensor(correct, dtype=torch.int64)  # int64: int8 would overflow for counts > 127
# correct = correct.to(device)
# torch.distributed.all_reduce(correct, op=torch.distributed.ReduceOp.SUM)
# total = torch.tensor(total, dtype=torch.int64)
# total = total.to(device)
# torch.distributed.all_reduce(total, op=torch.distributed.ReduceOp.SUM)
# test_accuracy = 100 * correct / total  # already global after the all_reduce; no further division by world_size needed

# print('Test accuracy: %.2f %%' % test_accuracy)

# Save the plot with the 6 random images and their predicted classes and prediction confidence
# (only rank 0 created the figure above)
if world_rank == 0:
    test_img_file_name = 'test_images_' + str(world_rank) + '.png'
    plt.savefig(test_img_file_name)

# Log the test accuracy and elapsed time to MLflow
if world_rank == 0:
    mlflow.log_metric("test accuracy", test_accuracy)

end_time = time.time()
elapsed_time = end_time - start_time
print('Elapsed time: ', elapsed_time)
if world_rank == 0:
    mlflow.log_metric("elapsed time", elapsed_time)

if world_rank == 0:
    # Save the plot with the 6 random images and their predicted classes and prediction confidence as an artifact in MLflow
    image = Image.open(test_img_file_name)
    image = image.convert('RGBA')
    image_buffer = np.array(image)
    image_buffer = image_buffer[:, :, [2, 1, 0, 3]]
    image_buffer = np.ascontiguousarray(image_buffer)
    artifact_file_name = "inference_on_test_images_" + str(world_rank) + ".png"
    mlflow.log_image(image_buffer, artifact_file=artifact_file_name)

# End the MLflow run
if mlflow.active_run():
    mlflow.end_run()

dist.destroy_process_group()


However, when I go to the process_00 (rank 0) log, I don’t see an error there, so it’s very hard to debug.
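
To narrow this down, I'm planning to run a bare all_reduce smoke test with the same launcher and container, so no dataset download or model is involved. A minimal sketch (smoke_test.py is just a placeholder name):

# smoke_test.py -- check that rank 0 can communicate at all, independent of train.py
import os
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl", init_method="env://")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

x = torch.ones(1, device="cuda:{}".format(local_rank))
dist.all_reduce(x)  # sums across ranks; every rank should print the world size
print("rank {}/{}: all_reduce -> {}".format(dist.get_rank(), dist.get_world_size(), x.item()))

dist.destroy_process_group()

If this also hangs waiting on rank 0, the problem is in the communication setup rather than in the training script itself.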