PyTorch 2 DistributedDataParallel

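I am using NVIDIA's PyTorch Docker image (nvcr.io/nvidia/pytorch:23.04-py3). It ships PyTorch 2 (2.1.0a0+fe05266) and NCCL 2.17.1. I am running DDP on an Azure cluster with 2 nodes, each with 2 M60 GPUs (compute capability 5.2). Wrapping the model in DistributedDataParallel fails with the AttributeError shown at the end of the log below. Does anyone know how I can fix this error?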

total 0
CPython
3.8.10
uname_result(system='Linux', node='7aa6ec4037f840a298e0306dd926ea9f000001', release='5.15.0-1029-azure', version='#36~20.04.1-Ubuntu SMP Tue Dec 6 17:00:26 UTC 2022', machine='x86_64', processor='x86_64')
NCCL version is:  (2, 17, 1)
System information: Linux #36~20.04.1-Ubuntu SMP Tue Dec 6 17:00:26 UTC 2022
Python version: 3.8.10
MLflow version: 2.3.2
MLflow module location: /usr/local/lib/python3.8/dist-packages/mlflow/__init__.py
Tracking URI: URI
Registry URI: URI
MLflow environment variables:
  MLFLOW_DISABLE_ENV_MANAGER_CONDA_WARNING: True
  MLFLOW_EXPERIMENT_ID: 97cdf0ad-6496-41c6-92a3-609b2474fa29
  MLFLOW_EXPERIMENT_NAME: dev_CIFAR10_DDP_train_test2
  MLFLOW_RUN_ID: 63885446-8a41-40e4-9dd4-fd0867a260ba
  MLFLOW_TRACKING_TOKEN: token
  MLFLOW_TRACKING_URI: URI
MLflow dependencies:
  Flask: 2.3.2
  Jinja2: 3.1.2
  alembic: 1.11.1
  click: 8.1.3
  cloudpickle: 2.2.1
  databricks-cli: 0.17.7
  docker: 6.1.2
  entrypoints: 0.4
  gitpython: 3.1.31
  gunicorn: 20.1.0
  importlib-metadata: 6.3.0
  markdown: 3.4.3
  matplotlib: 3.7.1
  numpy: 1.22.2
  packaging: 23.0
  pandas: 1.5.2
  protobuf: 3.20.3
  pyarrow: 10.0.1.dev0+ga6eabc2b.d20230410
  pytz: 2023.3
  pyyaml: 6.0
  querystring-parser: 1.2.4
  requests: 2.28.2
  scikit-learn: 1.2.0
  scipy: 1.10.1
  sqlalchemy: 2.0.15
  sqlparse: 0.4.4
INFO:__main__:os.getpid() is 25 and initializing process group with {'MASTER_ADDR': '10.0.0.5', 'MASTER_PORT': '6105', 'LOCAL_RANK': '1', 'RANK': '1', 'WORLD_SIZE': '4'}
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 1
INFO:torch.distributed.distributed_c10d:Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
7aa6ec4037f840a298e0306dd926ea9f000001:25:25 [1] NCCL INFO cudaDriverVersion 12010
7aa6ec4037f840a298e0306dd926ea9f000001:25:25 [1] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth0
7aa6ec4037f840a298e0306dd926ea9f000001:25:25 [1] NCCL INFO Bootstrap : Using eth0:10.0.0.5<0>
7aa6ec4037f840a298e0306dd926ea9f000001:25:25 [1] NCCL INFO NET/Plugin: Failed to find ncclNetPlugin_v6 symbol.
7aa6ec4037f840a298e0306dd926ea9f000001:25:25 [1] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin (v5)
7aa6ec4037f840a298e0306dd926ea9f000001:25:25 [1] NCCL INFO NET/Plugin: Failed to find ncclCollNetPlugin_v6 symbol.
7aa6ec4037f840a298e0306dd926ea9f000001:25:25 [1] NCCL INFO NET/Plugin: Loaded coll plugin SHARP (v5)
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO P2P plugin IBext
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth0
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO NET/IB : No device found.
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO NCCL_IB_DISABLE set by environment to 1.
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth0
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO NET/Socket : Using [0]eth0:10.0.0.5<0>
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Using network Socket
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0001-0000-3130-444531334632/pci0001:00/0001:00:00.0/../max_link_speed, ignoring
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0001-0000-3130-444531334632/pci0001:00/0001:00:00.0/../max_link_width, ignoring
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0002-0000-3130-444531334632/pci0002:00/0002:00:00.0/../max_link_speed, ignoring
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0002-0000-3130-444531334632/pci0002:00/0002:00:00.0/../max_link_width, ignoring
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Topology detection: network path /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/000d3a4f-ce52-000d-3a4f-ce52000d3a4f is not a PCI device (vmbus). Attaching to first CPU
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO KV Convert to int : could not find value of '' in dictionary, falling back to 60
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO KV Convert to int : could not find value of '' in dictionary, falling back to 60
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO === System : maxBw 5.0 totalBw 12.0 ===
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO CPU/0 (1/1/1)
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO + PCI[5000.0] - NIC/0
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO                 + NET[5.0] - NET/0 (0/0/5.000000)
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO + PCI[12.0] - GPU/100000 (0)
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO + PCI[12.0] - GPU/200000 (1)
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO ==========================================
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO GPU/100000 :GPU/100000 (0/5000.000000/LOC) GPU/200000 (2/12.000000/PHB) CPU/0 (1/12.000000/PHB) NET/0 (3/5.000000/PHB)
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO GPU/200000 :GPU/100000 (2/12.000000/PHB) GPU/200000 (0/5000.000000/LOC) CPU/0 (1/12.000000/PHB) NET/0 (3/5.000000/PHB)
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO NET/0 :GPU/100000 (3/5.000000/PHB) GPU/200000 (3/5.000000/PHB) CPU/0 (2/5.000000/PHB) NET/0 (0/5000.000000/LOC)
/usr/local/lib/python3.8/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.8/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
World size: 4
local rank is 1 and world rank is 1
PyTorch version is 2.1.0a0+fe05266 and torchvision version is 0.15.0a0

  0%|          | 0.00/97.8M [00:00<?, ?B/s]
  8%|▊         | 7.94M/97.8M [00:00<00:01, 82.9MB/s]
 38%|███▊      | 37.5M/97.8M [00:00<00:00, 216MB/s]
 64%|██████▎   | 62.2M/97.8M [00:00<00:00, 236MB/s]
 91%|█████████ | 89.0M/97.8M [00:00<00:00, 253MB/s]
100%|██████████| 97.8M/97.8M [00:00<00:00, 236MB/s]
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Setting affinity for GPU 1 to 0fff
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Pattern 4, crossNic 0, nChannels 1, bw 5.000000/5.000000, type PHB/PHB, sameChannels 1
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO  0 : NET/0 GPU/0 GPU/1 NET/0
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Pattern 1, crossNic 0, nChannels 1, bw 6.000000/5.000000, type PHB/PHB, sameChannels 1
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO  0 : NET/0 GPU/0 GPU/1 NET/0
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Pattern 3, crossNic 0, nChannels 0, bw 0.000000/0.000000, type NVL/PIX, sameChannels 1
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Tree 0 : 0 -> 1 -> -1/-1/-1
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Tree 1 : 0 -> 1 -> -1/-1/-1
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Ring 00 : 0 -> 1 -> 2
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Ring 01 : 0 -> 1 -> 2
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] -1/-1/-1->1->0
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO P2P Chunksize set to 131072
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Channel 00/0 : 1[200000] -> 2[100000] [send] via NET/Socket/0
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Channel 01/0 : 1[200000] -> 2[100000] [send] via NET/Socket/0
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Connected all rings
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Channel 00 : 1[200000] -> 0[100000] via SHM/direct/direct
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Channel 01 : 1[200000] -> 0[100000] via SHM/direct/direct
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO Connected all trees
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO NCCL_P2P_PXN_LEVEL set by environment to 0.
7aa6ec4037f840a298e0306dd926ea9f000001:25:127 [1] NCCL INFO comm 0x9096940 rank 1 nranks 4 cudaDev 1 busId 200000 commId 0x3346fdbb92a0dc2a - Init COMPLETE
Traceback (most recent call last):
  File "train.py", line 163, in <module>
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 631, in __init__
    current_cga = default_pg_nccl.options.config.cga_cluster_size
AttributeError: 'torch._C._distributed_c10d._ProcessGroupWrapper' object has no attribute 'options'
7aa6ec4037f840a298e0306dd926ea9f000001:25:129 [1] NCCL INFO [Service thread] Connection closed by localRank 1
7aa6ec4037f840a298e0306dd926ea9f000001:25:25 [1] NCCL INFO comm 0x9096940 rank 1 nranks 4 cudaDev 1 busId 200000 - Abort COMPLETE
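
One thing that might be worth ruling out. This is an assumption, not something the log above confirms. In upstream PyTorch the NCCL process group gets wrapped in a `_ProcessGroupWrapper` when `TORCH_DISTRIBUTED_DEBUG=DETAIL` is set, and the traceback shows DDP reading `options` from exactly such a wrapper. The sketch below only inspects environment variables. The helper name `debug_wrapper_expected` is hypothetical, and whether the Azure job sets this variable at all is not known from this post.

```python
import os


def debug_wrapper_expected(env=None):
    """Return True if TORCH_DISTRIBUTED_DEBUG is set to DETAIL in `env`.

    `env` defaults to os.environ; passing a plain dict makes the check testable.
    """
    env = os.environ if env is None else env
    return env.get("TORCH_DISTRIBUTED_DEBUG", "OFF").strip().upper() == "DETAIL"


if __name__ == "__main__":
    if debug_wrapper_expected():
        # Assumption: a lower debug level avoids the wrapper; verify on your setup.
        print("TORCH_DISTRIBUTED_DEBUG=DETAIL is set; try OFF or INFO and rerun")
    else:
        print("TORCH_DISTRIBUTED_DEBUG is not DETAIL; the cause may be elsewhere")
```

Running this at the top of train.py on each rank, before `dist.init_process_group`, would show whether the variable is present in the job environment.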

Here’s my Dockerfile:

# check release notes https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html
# nvidia containers https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-23-04.html#rel-23-04
# FROM nvcr.io/nvidia/pytorch:22.04-py3
#FROM nvcr.io/nvidia/pytorch:23.02-py3 #requires GPUs with compute capability of 5+
# FROM nvcr.io/nvidia/pytorch:22.12-py3
# FROM nvcr.io/nvidia/pytorch:21.06-py3
FROM nvcr.io/nvidia/pytorch:23.04-py3


##############################################################################
# NCCL TESTS
##############################################################################
ENV NCCL_TESTS_TAG=v2.11.0

# NOTE: adding gencodes to support K80, M60, V100, A100
RUN mkdir /tmp/nccltests && \
    cd /tmp/nccltests && \
    git clone -b ${NCCL_TESTS_TAG} https://github.com/NVIDIA/nccl-tests.git && \
    cd nccl-tests && \
    make \
    MPI=1 MPI_HOME=/opt/hpcx/ompi \
    NVCC_GENCODE="-gencode=arch=compute_52,code=sm_52" \
    CUDA_HOME=/usr/local/cuda && \
    cp ./build/* /usr/local/bin && \
    rm -rf /tmp/nccltests

# Install dependencies missing in this container
# NOTE: container already has matplotlib==3.5.1 tqdm==4.62.0
COPY requirements.txt ./
RUN pip install -r requirements.txt


# add ndv4-topo.xml
RUN mkdir /opt/microsoft/
ADD ./ndv4-topo.xml /opt/microsoft

# to use on A100, enable env var below in your job
# ENV NCCL_TOPO_FILE="/opt/microsoft/ndv4-topo.xml"

# adjusts the level of info from NCCL tests
ENV NCCL_DEBUG="INFO"
ENV NCCL_DEBUG_SUBSYS="GRAPH,INIT,ENV"

# Relaxed Ordering can greatly help the performance of Infiniband networks in virtualized environments.
# ENV NCCL_IB_PCI_RELAXED_ORDERING="1"
# suggested to set ENV NCCL_IB_PCI_RELAXED_ORDERING to 0 for NCCL 2.18.1
ENV NCCL_IB_PCI_RELAXED_ORDERING="0" 
ENV CUDA_DEVICE_ORDER="PCI_BUS_ID"
ENV NCCL_SOCKET_IFNAME="eth0"
ENV NCCL_P2P_PXN_LEVEL="0"
# ENV NCCL_P2P_DISABLE="1"
# ENV NCCL_SOCKET_IFNAME='lo'
ENV NCCL_IB_DISABLE="1"

Here’s my requirements.txt:

# torch and torchvision compatibility matrix https://github.com/pytorch/pytorch/wiki/PyTorch-Versions
mlflow==2.3.2
azureml-mlflow==1.50.0
psutil==5.9.0

# for unit testing
pytest==7.1.2

Here’s the train.py code:

import time
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import mlflow
import os
import datetime
import gc
import configparser
import logging
import argparse

from datetime import datetime, timedelta

from PIL import Image
from torch.distributed.elastic.multiprocessing.errors import record  # TODO: create main() and decorate it with @record later

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

import platform


print(platform.python_implementation())
print(platform.python_version())
print(platform.uname())

start_time = time.time()


conf_parser = argparse.ArgumentParser(
    description=__doc__, # printed with -h/--help
    # Don't mess with format of description
    formatter_class=argparse.RawDescriptionHelpFormatter,
    # Turn off help, so we print all options in response to -h
    add_help=False
    )

parser = argparse.ArgumentParser()

parser.add_argument('--data',  
    default = "", 
    help='path to training data')

parser.add_argument('--checkpoints', 
    type=str,
    default=None,
    required=False, 
    help='Path to read/write checkpoints')

# Read the config but do not overwrite the args written 
args, remaining_argv = conf_parser.parse_known_args()
defaults = { "option":"default" }

opt = parser.parse_args(remaining_argv)


# torch.cuda.empty_cache()
# gc.collect()

torch.backends.cudnn.benchmark=False #TODO: is this needed?
torch.backends.cudnn.deterministic=True


print("NCCL version is: ", torch.cuda.nccl.version())


# MLflow >= 2.0
mlflow.doctor()

# Set the seed for reproducibility
torch.manual_seed(42)

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("requests").setLevel(logging.DEBUG)
logging.getLogger("azureml").setLevel(logging.DEBUG)
logging.getLogger("azure").setLevel(logging.DEBUG)
logging.getLogger("azure.core").setLevel(logging.DEBUG)
logging.getLogger("azure.mlflow").setLevel(logging.DEBUG)


logger = logging.getLogger(__name__)
env_dict = {
    key: os.environ[key]
    for key in ("MASTER_ADDR", "MASTER_PORT","LOCAL_RANK", "RANK", "WORLD_SIZE")
}

logger.info("os.getpid() is {} and initializing process group with {}".format(os.getpid(), env_dict))

# Set up the data loading parameters
batch_size = 128
num_epochs = 1
num_workers = 4
pin_memory = True

# Get the world size and rank to determine the process group
world_size = int(os.environ['WORLD_SIZE'])
world_rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])

print("World size:", world_size)
print("local rank is {} and world rank is {}".format(local_rank, world_rank))


print("PyTorch version is {} and torchvision version is {}".format(torch.__version__, torchvision.__version__))
is_distributed = world_size > 1

if is_distributed:
    batch_size = batch_size // world_size
    batch_size = max(batch_size, 1)

# Set the backend to NCCL for distributed training
dist.init_process_group(backend="nccl",
                        init_method="env://",
                        world_size=world_size,
                        rank=world_rank,
                        timeout=timedelta(seconds=2000))

# Set the device to the current local rank
torch.cuda.set_device(local_rank)
device = torch.device('cuda', local_rank)

dist.barrier()

# Define the transforms for the dataset
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

# Load the CIFAR-10 dataset

# data_root = './data_' + str(world_rank)
data_root = opt.data
train_dataset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=False, transform=transform_train)
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset=train_dataset, num_replicas=world_size, rank=world_rank, shuffle=True) if is_distributed else None
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), num_workers=num_workers, pin_memory=pin_memory, sampler=train_sampler)

test_dataset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=False, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

# Define the ResNet50 model
model = torchvision.models.resnet50(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)

# Move the model to the GPU
model = model.to(device)

# Wrap the model with DistributedDataParallel
if is_distributed:
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # later check which one yields better results?

# Train the model for the specified number of epochs
for epoch in range(num_epochs):
    running_loss = 0.0
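    # Reseed the sampler's shuffle for this epoch; without this call every epoch
    # iterates the data in the same order. The sampler is None when not distributed.
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)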
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()

    #print('[Epoch %d] loss: %.3f' % (epoch + 1, running_loss))
        if batch_idx % 100 == 99: #print every 100 mini-batches
            print('[%d, %5d] loss: %.3f' % 
                  (epoch +1, batch_idx+1, running_loss / 100))
            running_loss = 0.0
    if world_rank == 0:
        # Log the loss and running loss as MLFlow metrics
        mlflow.log_metric("loss", loss.item())
        mlflow.log_metric("running loss", running_loss)
        
print("Finished training!")

dist.barrier()
# Save the trained model
if world_rank == 0:
    checkpoints_path = "train_checkpoints"
    os.makedirs(checkpoints_path, exist_ok=True)
    torch.save(model.state_dict(), '{}/{}-{}.pth'.format(checkpoints_path, 'resnet50_cifar10', world_rank))
    mlflow.pytorch.log_model(model, "resnet50_cifar10_{}.pth".format(world_rank))
    # mlflow.log_artifact('{}/{}-{}.pth'.format(checkpoints_path, 'resnet50_cifar10', world_rank), artifact_path="model_state_dict")

    # Evaluate the model on the test set and save inference on 6 random images
    correct = 0
    total = 0
    with torch.no_grad():
        fig, axs = plt.subplots(2, 3, figsize=(8, 6), dpi=100)
        axs = axs.flatten()
        count = 0
        for data in test_loader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Save the inference on the 6 random images
            if count < 6:
                image = np.transpose(inputs[0].cpu().numpy(), (1, 2, 0))
                confidence = torch.softmax(outputs, dim=1)[0][predicted[0]].cpu().numpy()
                class_name = test_dataset.classes[predicted[0]]
                axs[count].imshow(image)
                axs[count].set_title(f'Class: {class_name}\nConfidence: {confidence:.2f}')
                axs[count].axis('off')
                count += 1
            
            
    test_accuracy = 100 * correct / total
    print('Test accuracy: %.2f %%' % test_accuracy)

# # Average the test accuracy across all processes

# correct = torch.tensor(correct, dtype=torch.int8)
# correct = correct.to(device)
# torch.distributed.all_reduce(correct, op=torch.distributed.ReduceOp.SUM)
# total = torch.tensor(total, dtype=torch.torch.int8)
# total = total.to(device)
# torch.distributed.all_reduce(total, op=torch.distributed.ReduceOp.SUM)
# test_accuracy = 100 * correct / total
# test_accuracy /= world_size

# print('Test accuracy: %.2f %%' % test_accuracy)

# Save the plot with the 6 random images and their predicted classes and prediction confidence
test_img_file_name = 'test_images_' + str(world_rank) + '.png'
plt.savefig(test_img_file_name)

# Log the test accuracy and elapsed time to MLflow
if world_rank == 0:
    mlflow.log_metric("test accuracy", test_accuracy)

end_time = time.time()
elapsed_time = end_time - start_time
print('Elapsed time: ', elapsed_time)
if world_rank == 0:
    mlflow.log_metric("elapsed time", elapsed_time)

if world_rank == 0:
    # Save the plot with the 6 random images and their predicted classes and prediction confidence as an artifact in MLflow
    image = Image.open(test_img_file_name)
    image = image.convert('RGBA')
    image_buffer = np.array(image)
    image_buffer = image_buffer[:, :, [2, 1, 0, 3]]
    image_buffer = np.ascontiguousarray(image_buffer)
    artifact_file_name = "inference_on_test_images_" + str(world_rank) + ".png"
    mlflow.log_image(image_buffer, artifact_file=artifact_file_name)

# End the MLflow run
if mlflow.active_run():
    mlflow.end_run()

dist.destroy_process_group()
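
On the `set_epoch` question in the training loop: `DistributedSampler` seeds its shuffle with `seed + epoch`, so if the epoch is never updated, every epoch walks the data in the same order. The following is a stdlib-only imitation of that idea, not the real sampler. `shard_indices` is a hypothetical helper, and `random.Random` stands in for the `torch.Generator` that the sampler uses.

```python
import random


def shard_indices(n, num_replicas, rank, epoch, seed=0):
    """Shuffle range(n) with a seed derived from the epoch, then take this rank's slice."""
    order = list(range(n))
    # Every rank uses the same seed + epoch, so all ranks agree on the full order.
    random.Random(seed + epoch).shuffle(order)
    # Strided split: each rank takes every num_replicas-th index, starting at its rank.
    return order[rank::num_replicas]


# Without set_epoch the epoch stays at 0, so each "epoch" repeats the same order.
stale = [shard_indices(8, 2, 0, epoch=0) for _ in range(3)]
# With set_epoch(epoch) the shuffle is reseeded for every epoch.
fresh = [shard_indices(8, 2, 0, epoch=e) for e in range(3)]

print("epoch fixed at 0:", stale)
print("epoch updated:   ", fresh)
```

Within one epoch the two ranks' slices are disjoint and together cover every index, which is why all ranks have to pass the same epoch value.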

Hey, I am also getting a similar error. Were you able to find a solution?