Problem with model = torchvision.models.resnet50(pretrained=True) in a multi-node, multi-GPU setting with DistributedDataParallel

I get the error below when I use a STANDARD_NC24 cluster in Microsoft Azure (enterprise) with 4 nodes, each with 4 K80 GPUs. I am using DistributedDataParallel with the NCCL backend over eth0. Because I run the job as a pipeline, I use env:// as the dist_url (init_method) for init_process_group.
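
Since env:// is used, the rendezvous depends entirely on environment variables set by the pipeline. A minimal sketch of a fail-fast check each rank could run before init_process_group; the helper name read_dist_env and the 4-GPUs-per-node rank layout are assumptions for illustration, not something the Azure pipeline guarantees:

```python
import os
from typing import Dict, Mapping

# Variables that init_process_group(init_method="env://") reads (MASTER_ADDR,
# MASTER_PORT, and WORLD_SIZE/RANK unless passed explicitly), plus LOCAL_RANK,
# which train.py reads itself.
REQUIRED_VARS = ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK", "LOCAL_RANK")


def read_dist_env(environ: Mapping[str, str] = os.environ,
                  gpus_per_node: int = 4) -> Dict[str, object]:
    """Hypothetical helper: fail fast if an env:// variable is missing."""
    missing = [name for name in REQUIRED_VARS if name not in environ]
    if missing:
        raise RuntimeError(f"env:// rendezvous variables not set: {missing}")
    rank = int(environ["RANK"])
    return {
        "master": f"{environ['MASTER_ADDR']}:{environ['MASTER_PORT']}",
        "world_size": int(environ["WORLD_SIZE"]),
        "rank": rank,
        "local_rank": int(environ["LOCAL_RANK"]),
        # Assumes the launcher numbers ranks node by node, gpus_per_node at a time.
        "node_index": rank // gpus_per_node,
    }
```

With RANK=7 and LOCAL_RANK=3, as printed by the failing process below, node_index is 1 under that layout.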

5225797bbbfa41b58e9cac81360abc94000001:47:47 [3] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth0
5225797bbbfa41b58e9cac81360abc94000001:47:47 [3] NCCL INFO Bootstrap : Using eth0:10.0.0.5<0>
5225797bbbfa41b58e9cac81360abc94000001:47:47 [3] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so
5225797bbbfa41b58e9cac81360abc94000001:47:47 [3] NCCL INFO P2P plugin IBext
5225797bbbfa41b58e9cac81360abc94000001:47:47 [3] NCCL INFO NCCL_IB_PCI_RELAXED_ORDERING set by environment to 1.
5225797bbbfa41b58e9cac81360abc94000001:47:47 [3] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth0
5225797bbbfa41b58e9cac81360abc94000001:47:47 [3] NCCL INFO NET/IB : No device found.
5225797bbbfa41b58e9cac81360abc94000001:47:47 [3] NCCL INFO NCCL_IB_DISABLE set by environment to 1.
5225797bbbfa41b58e9cac81360abc94000001:47:47 [3] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth0
5225797bbbfa41b58e9cac81360abc94000001:47:47 [3] NCCL INFO NET/Socket : Using [0]eth0:10.0.0.5<0>
5225797bbbfa41b58e9cac81360abc94000001:47:47 [3] NCCL INFO Using network Socket
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0001-0000-3130-444531303244/pci0001:00/0001:00:00.0/../max_link_speed, ignoring
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0001-0000-3130-444531303244/pci0001:00/0001:00:00.0/../max_link_width, ignoring
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0002-0000-3130-444531303244/pci0002:00/0002:00:00.0/../max_link_speed, ignoring
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0002-0000-3130-444531303244/pci0002:00/0002:00:00.0/../max_link_width, ignoring
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0003-0000-3130-444531303244/pci0003:00/0003:00:00.0/../max_link_speed, ignoring
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0003-0000-3130-444531303244/pci0003:00/0003:00:00.0/../max_link_width, ignoring
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0004-0000-3130-444531303244/pci0004:00/0004:00:00.0/../max_link_speed, ignoring
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/47505500-0004-0000-3130-444531303244/pci0004:00/0004:00:00.0/../max_link_width, ignoring
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Topology detection: network path /sys/devices/LNXSYSTM:00/LNXSYBUS:00/PNP0A03:00/device:07/VMBUS:01/6045bd7e-4468-6045-bd7e-44686045bd7e is not a PCI device (vmbus). Attaching to first CPU
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO KV Convert to int : could not find value of '' in dictionary, falling back to 60
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO KV Convert to int : could not find value of '' in dictionary, falling back to 60
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO KV Convert to int : could not find value of '' in dictionary, falling back to 60
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO KV Convert to int : could not find value of '' in dictionary, falling back to 60
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Attribute coll of node net not found
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO === System : maxWidth 5.0 totalWidth 12.0 ===
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO CPU/0 (1/1/1)
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO + PCI[5000.0] - NIC/0
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO                 + NET[5.0] - NET/0 (0/0/5.000000)
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO + PCI[12.0] - GPU/100000 (4)
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO + PCI[12.0] - GPU/200000 (5)
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO + PCI[12.0] - GPU/300000 (6)
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO + PCI[12.0] - GPU/400000 (7)
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO ==========================================
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO GPU/100000 :GPU/100000 (0/5000.000000/LOC) GPU/200000 (2/12.000000/PHB) GPU/300000 (2/12.000000/PHB) GPU/400000 (2/12.000000/PHB) CPU/0 (1/12.000000/PHB) NET/0 (3/5.000000/PHB)
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO GPU/200000 :GPU/100000 (2/12.000000/PHB) GPU/200000 (0/5000.000000/LOC) GPU/300000 (2/12.000000/PHB) GPU/400000 (2/12.000000/PHB) CPU/0 (1/12.000000/PHB) NET/0 (3/5.000000/PHB)
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO GPU/300000 :GPU/100000 (2/12.000000/PHB) GPU/200000 (2/12.000000/PHB) GPU/300000 (0/5000.000000/LOC) GPU/400000 (2/12.000000/PHB) CPU/0 (1/12.000000/PHB) NET/0 (3/5.000000/PHB)
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO GPU/400000 :GPU/100000 (2/12.000000/PHB) GPU/200000 (2/12.000000/PHB) GPU/300000 (2/12.000000/PHB) GPU/400000 (0/5000.000000/LOC) CPU/0 (1/12.000000/PHB) NET/0 (3/5.000000/PHB)
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO NET/0 :GPU/100000 (3/5.000000/PHB) GPU/200000 (3/5.000000/PHB) GPU/300000 (3/5.000000/PHB) GPU/400000 (3/5.000000/PHB) CPU/0 (2/5.000000/PHB) NET/0 (0/5000.000000/LOC)
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Pattern 4, crossNic 0, nChannels 1, speed 5.000000/5.000000, type PHB/PHB, sameChannels 1
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO  0 : NET/0 GPU/4 GPU/5 GPU/6 GPU/7 NET/0
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Pattern 1, crossNic 0, nChannels 1, speed 6.000000/5.000000, type PHB/PHB, sameChannels 1
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO  0 : NET/0 GPU/4 GPU/5 GPU/6 GPU/7 NET/0
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Pattern 3, crossNic 0, nChannels 0, speed 0.000000/0.000000, type NVL/PIX, sameChannels 1
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Ring 00 : 6 -> 7 -> 8
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Ring 01 : 6 -> 7 -> 8
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Trees [0] -1/-1/-1->7->6 [1] -1/-1/-1->7->6
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Setting affinity for GPU 3 to 0fff
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Channel 00 : 7[400000] -> 8[100000] [send] via NET/Socket/0
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Channel 01 : 7[400000] -> 8[100000] [send] via NET/Socket/0
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Connected all rings
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Channel 00 : 7[400000] -> 6[300000] via direct shared memory
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Channel 01 : 7[400000] -> 6[300000] via direct shared memory
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO Connected all trees
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO threadThresholds 8/8/64 | 128/8/64 | 8/8/512
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO 2 coll channels, 2 p2p channels, 1 p2p channels per peer
5225797bbbfa41b58e9cac81360abc94000001:47:213 [3] NCCL INFO comm 0x14ec70001240 rank 7 nranks 16 cudaDev 3 busId 400000 - Init COMPLETE
NCCL version is:  (2, 10, 3)
MLflow version: 2.3.2
Tracking URI: azureml:URI
Artifact URI: azureml:URI
World size: 16
local rank is 3 and world rank is 7
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data_7/cifar-10-python.tar.gz
Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data_7/cifar-10-python.tar.gz

  0%|          | 0/170498071 [00:00<?, ?it/s]
170499072it [00:03, 51834695.76it/s]                              
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
Extracting ./data_7/cifar-10-python.tar.gz to ./data_7
Files already downloaded and verified
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/urllib/request.py", line 1354, in do_open
    h.request(req.get_method(), req.selector, req.data, headers,
  File "/opt/conda/lib/python3.8/http/client.py", line 1256, in request
    self._send_request(method, url, body, headers, encode_chunked)
  File "/opt/conda/lib/python3.8/http/client.py", line 1302, in _send_request
    self.endheaders(body, encode_chunked=encode_chunked)
  File "/opt/conda/lib/python3.8/http/client.py", line 1251, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/opt/conda/lib/python3.8/http/client.py", line 1011, in _send_output
    self.send(msg)
  File "/opt/conda/lib/python3.8/http/client.py", line 951, in send
    self.connect()
  File "/opt/conda/lib/python3.8/http/client.py", line 1418, in connect
    super().connect()
  File "/opt/conda/lib/python3.8/http/client.py", line 922, in connect
    self.sock = self._create_connection(
  File "/opt/conda/lib/python3.8/socket.py", line 787, in create_connection
    for res in getaddrinfo(host, port, 0, SOCK_STREAM):
  File "/opt/conda/lib/python3.8/socket.py", line 918, in getaddrinfo
    for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
socket.gaierror: [Errno -3] Temporary failure in name resolution

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "train.py", line 93, in <module>
    model = torchvision.models.resnet50(pretrained=True)
  File "/opt/conda/lib/python3.8/site-packages/torchvision/models/resnet.py", line 331, in resnet50
    return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torchvision/models/resnet.py", line 296, in _resnet
    state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
  File "/opt/conda/lib/python3.8/site-packages/torch/hub.py", line 591, in load_state_dict_from_url
    download_url_to_file(url, cached_file, hash_prefix, progress=progress)
  File "/opt/conda/lib/python3.8/site-packages/torch/hub.py", line 457, in download_url_to_file
    u = urlopen(req)
  File "/opt/conda/lib/python3.8/urllib/request.py", line 222, in urlopen
    return opener.open(url, data, timeout)
  File "/opt/conda/lib/python3.8/urllib/request.py", line 525, in open
    response = self._open(req, data)
  File "/opt/conda/lib/python3.8/urllib/request.py", line 542, in _open
    result = self._call_chain(self.handle_open, protocol, protocol +
  File "/opt/conda/lib/python3.8/urllib/request.py", line 502, in _call_chain
    result = func(*args)
  File "/opt/conda/lib/python3.8/urllib/request.py", line 1397, in https_open
    return self.do_open(http.client.HTTPSConnection, req,
  File "/opt/conda/lib/python3.8/urllib/request.py", line 1357, in do_open
    raise URLError(err)
urllib.error.URLError: <urlopen error [Errno -3] Temporary failure in name resolution>
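
The traceback shows the failure happening in socket.getaddrinfo while resolving the host of https://download.pytorch.org/models/resnet50-0676ba61.pth, i.e. a DNS lookup failure before any connection is opened, after NCCL had already reported Init COMPLETE. A small diagnostic sketch, assuming it is run on every node (the helper name check_dns is hypothetical), that resolves the hosts from this log through the same getaddrinfo(host, port, 0, SOCK_STREAM) call that socket.create_connection makes:

```python
import socket
from typing import Callable, Dict, Iterable

# Hosts taken from the log and traceback in this post.
HOSTS = ("download.pytorch.org", "www.cs.toronto.edu")


def check_dns(hosts: Iterable[str], port: int = 443,
              resolver: Callable = socket.getaddrinfo) -> Dict[str, str]:
    """Resolve each host the way http.client does and report the outcome."""
    results = {}
    for host in hosts:
        try:
            infos = resolver(host, port, 0, socket.SOCK_STREAM)
            # Each entry is (family, type, proto, canonname, sockaddr); sockaddr[0] is the IP.
            addresses = sorted({info[4][0] for info in infos})
            results[host] = "ok: " + ", ".join(addresses)
        except socket.gaierror as exc:
            results[host] = f"failed: {exc}"
    return results
```

Printing socket.gethostname() next to check_dns(HOSTS) on each node would show whether resolution fails only on some nodes or only intermittently.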

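In the traceback, resnet50(pretrained=True) goes through torch.hub.load_state_dict_from_url, which only downloads when the checkpoint file is not already in its cache directory. Staging resnet50-0676ba61.pth there beforehand (for example from storage the job can read) would therefore avoid the network call. A minimal sketch that re-implements how torch.hub picks that path in recent versions: TORCH_HOME, else $XDG_CACHE_HOME/torch, else ~/.cache/torch, followed by hub/checkpoints/ and the file name from the URL. The function name is hypothetical, it ignores torch.hub.set_dir(), and the result should be checked against torch.hub.get_dir() in the actual environment:

```python
import os
from typing import Mapping
from urllib.parse import urlparse

RESNET50_URL = "https://download.pytorch.org/models/resnet50-0676ba61.pth"


def expected_checkpoint_path(url: str, environ: Mapping[str, str] = os.environ) -> str:
    """Re-implementation (not the torch API) of where torch.hub caches `url`.

    Ignores torch.hub.set_dir(), which overrides this location.
    """
    torch_home = environ.get("TORCH_HOME")
    if torch_home is None:
        cache_root = environ.get("XDG_CACHE_HOME", "~/.cache")
        torch_home = os.path.join(cache_root, "torch")
    # torch.hub uses the last path component of the URL as the file name.
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(os.path.expanduser(torch_home), "hub", "checkpoints", filename)
```

With HOME=/root and neither variable set, this yields /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth, the path shown in the log.
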
The code is:

import time
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import mlflow
import os
import datetime

import configparser
import logging
import argparse

from PIL import Image

import ssl
ssl._create_default_https_context = ssl._create_unverified_context


start_time = time.time()

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


print("NCCL version is: ", torch.cuda.nccl.version())
print("MLflow version:", mlflow.__version__)
print("Tracking URI:", mlflow.get_tracking_uri())
print("Artifact URI:", mlflow.get_artifact_uri())

# Set the seed for reproducibility
torch.manual_seed(42)

# Set up the data loading parameters
batch_size = 128
num_epochs = 10
num_workers = 4
pin_memory = True

# Get the world size and rank to determine the process group
world_size = int(os.environ['WORLD_SIZE'])
world_rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])

print("World size:", world_size)
print("local rank is {} and world rank is {}".format(local_rank, world_rank))

is_distributed = world_size > 1

if is_distributed:
    batch_size = batch_size // world_size
    batch_size = max(batch_size, 1)

# Set the backend to NCCL for distributed training
dist.init_process_group(backend="nccl",
                        init_method="env://",
                        world_size=world_size,
                        rank=world_rank)

# Set the device to the current local rank
torch.cuda.set_device(local_rank)
device = torch.device('cuda', local_rank)

dist.barrier()

# Define the transforms for the dataset
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

# Load the CIFAR-10 dataset

data_root = './data_' + str(world_rank)
train_dataset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset=train_dataset, num_replicas=world_size, rank=world_rank, shuffle=True) if is_distributed else None
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), num_workers=num_workers, pin_memory=pin_memory, sampler=train_sampler)

test_dataset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

# Define the ResNet50 model
model = torchvision.models.resnet50(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)

# Move the model to the GPU
model = model.to(device)

# Wrap the model with DistributedDataParallel
if is_distributed:
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model for the specified number of epochs
for epoch in range(num_epochs):
    running_loss = 0.0
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)  # why is this line necessary?
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()

    print('[Epoch %d] loss: %.3f' % (epoch + 1, running_loss))
    if world_rank == 0:
        # Log the loss and running loss as MLFlow metrics
        mlflow.log_metric("loss", loss.item(), step=epoch)
        mlflow.log_metric("running loss", running_loss, step=epoch)

dist.barrier()
# Save the trained model
if world_rank == 0:
    checkpoints_path = "train_checkpoints"
    os.makedirs(checkpoints_path, exist_ok=True)
    torch.save(model.state_dict(), '{}/{}-{}.pth'.format(checkpoints_path, 'resnet50_cifar10', world_rank))
    mlflow.pytorch.log_model(model, "resnet50_cifar10_{}.pth".format(world_rank))
    # mlflow.log_artifact('{}/{}-{}.pth'.format(checkpoints_path, 'resnet50_cifar10', world_rank), artifact_path="model_state_dict")

# Evaluate the model on the first 6 test batches and plot the first image of each batch
model.eval()  # use BatchNorm running statistics instead of per-batch statistics
correct = 0
total = 0
with torch.no_grad():
    fig, axs = plt.subplots(2, 3, figsize=(8, 6), dpi=100)
    axs = axs.flatten()
    count = 0
    for data in test_loader:
        if count == 6:
            break
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # transform_test has no Normalize, so apply the same (x - 0.5) / 0.5 as
        # transform_train here and keep the raw [0, 1] inputs for plotting
        outputs = model((inputs - 0.5) / 0.5)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Plot the first image of this batch with its predicted class and confidence
        image = np.transpose(inputs[0].cpu().numpy(), (1, 2, 0))
        confidence = torch.softmax(outputs, dim=1)[0][predicted[0]].item()
        class_name = test_dataset.classes[predicted[0].item()]
        axs[count].imshow(image)
        axs[count].set_title(f'Class: {class_name}\nConfidence: {confidence:.2f}')
        axs[count].axis('off')
        count += 1

# Aggregate the test accuracy across all processes
# int64 counters: int8 overflows past 127 once the counts are summed across ranks
correct = torch.tensor(correct, dtype=torch.long, device=device)
torch.distributed.all_reduce(correct, op=torch.distributed.ReduceOp.SUM)
total = torch.tensor(total, dtype=torch.long, device=device)
torch.distributed.all_reduce(total, op=torch.distributed.ReduceOp.SUM)
# Both counts are already summed over all ranks, so their ratio is the global accuracy
test_accuracy = (100.0 * correct / total).item()

print('Test accuracy: %.2f %%' % test_accuracy)

# Save the plot with the 6 random images and their predicted classes and prediction confidence
test_img_file_name = 'test_images_' + str(world_rank) + '.png'
plt.savefig(test_img_file_name)

# Log the test accuracy and elapsed time to MLflow
if world_rank == 0:
    mlflow.log_metric("test accuracy", test_accuracy)

end_time = time.time()
elapsed_time = end_time - start_time
print('Elapsed time: ', elapsed_time)
if world_rank == 0:
    mlflow.log_metric("elapsed time", elapsed_time)

# Save the plot with the 6 random images and their predicted classes and prediction confidence as an artifact in MLflow
# mlflow.log_image accepts a PIL image (or an RGB/RGBA numpy array), so no channel reordering is needed
image = Image.open(test_img_file_name).convert('RGBA')
artifact_file_name = "inference_on_test_images_" + str(world_rank) + ".png"
mlflow.log_image(image, artifact_file=artifact_file_name)

# End the MLflow run
if mlflow.active_run():
    mlflow.end_run()

dist.destroy_process_group()
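
In train.py, all 16 ranks call torchvision.models.resnet50(pretrained=True), and each rank also downloads CIFAR-10 into its own ./data_<rank> directory, so many processes may hit the network at the same moment. A common pattern is to let local rank 0 on each node fetch first while the other ranks wait at a barrier, then let them read the cached copy. Below is a minimal sketch, assuming the ranks on a node share one cache directory (the helper name run_local_zero_first is hypothetical, and the barrier is passed in so the sketch runs without torch). It reduces concurrent requests; it does not fix a node whose DNS lookups fail:

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def run_local_zero_first(local_rank: int, fetch: Callable[[], T],
                         barrier: Callable[[], object]) -> T:
    """Run fetch() on local rank 0 before any other rank on the node.

    Every rank calls barrier() exactly once, so a global barrier such as
    torch.distributed.barrier() also works: non-zero local ranks block until
    every local rank 0 has finished its fetch().
    """
    if local_rank != 0:
        barrier()  # wait until local rank 0 has populated the cache
    result = fetch()  # local rank 0 downloads; the others should hit the cache
    if local_rank == 0:
        barrier()  # release the ranks waiting above
    return result
```

In train.py this would look like model = run_local_zero_first(local_rank, lambda: torchvision.models.resnet50(pretrained=True), dist.barrier). If local rank 0 fails inside fetch(), the other ranks may block until the job is torn down or the process group times out.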

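Regarding the train_sampler.set_epoch(epoch) call in the training loop: with shuffle=True, DistributedSampler seeds its shuffle with seed + epoch, so all ranks draw the same permutation (and therefore disjoint shards), and set_epoch is what changes that permutation from one epoch to the next; without it every epoch reuses the epoch-0 order. Below is a simplified pure-Python illustration, not torch's code: random.Random stands in for torch's generator, so the concrete orders differ from the real sampler, and the function name is hypothetical:

```python
import random
from typing import List


def distributed_indices(dataset_len: int, num_replicas: int, rank: int,
                        epoch: int, seed: int = 0) -> List[int]:
    """Simplified stand-in for DistributedSampler(shuffle=True, drop_last=False).

    Assumes dataset_len >= num_replicas so a single round of padding suffices.
    """
    indices = list(range(dataset_len))
    # Seeded with seed + epoch: identical on every rank for a given epoch,
    # different once the epoch changes.
    random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating indices so the list splits evenly across replicas.
    num_samples = -(-dataset_len // num_replicas)  # ceil division
    total_size = num_samples * num_replicas
    indices += indices[: total_size - len(indices)]
    # Each rank takes every num_replicas-th index starting at its own rank.
    return indices[rank:total_size:num_replicas]
```

Ranks still agree on the permutation because they share seed and epoch, which keeps their shards disjoint apart from the padding.
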
Could you please guide me on how to fix this error?