DataParallel hangs indefinitely with A100 GPUs

I am trying to run DDP or DP with two A100 GPUs on a server, but whenever the computed loss is accessed, the process hangs indefinitely. I have to suspend it with Ctrl+Z and kill it manually.
I made a simple script that trains resnet18 on CIFAR-10, with or without DP.
The script runs smoothly when DP is disabled (on both A100 and non-A100 GPUs).
Additionally, the same script runs fine on other servers with 3090 GPUs.
The problem only appears when DP is enabled and training runs on the A100 GPUs.
Therefore, I suspect that PyTorch's DP or DDP does not work with A100 GPUs.
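
I have not confirmed this on the failing machine, but a plain GPU0 -> GPU1 tensor copy exercises the same inter-GPU traffic that DP's scatter/gather relies on, so a minimal check like the one below should reproduce the hang if peer-to-peer traffic is the culprit:

import torch

# peer access availability between the two A100s
print("peer 0 -> 1:", torch.cuda.can_device_access_peer(0, 1))
print("peer 1 -> 0:", torch.cuda.can_device_access_peer(1, 0))

x = torch.randn(1024, 1024, device="cuda:0")
y = x.to("cuda:1")   # direct GPU0 -> GPU1 copy; may hang or corrupt data on a broken P2P/ACS/IOMMU setup
torch.cuda.synchronize()
print("copy ok:", torch.allclose(x.cpu(), y.cpu()))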

Here is my server spec:
OS: Ubuntu 20.04 Server (updated to the latest)
CPU: Intel Xeon Gold 6226R x2
RAM: 256GB
GPUs: A100 x2

And this is the software spec:
anaconda: 4.12.0
pytorch: 1.12 + cuda 11.6
cuda: 11.6
cudnn: 8.4.1
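
These versions can also be verified from inside Python (a quick sanity check, separate from the repro script below):

import torch

print(torch.__version__)               # expected 1.12.x
print(torch.version.cuda)              # expected 11.6
print(torch.backends.cudnn.version())  # cuDNN version PyTorch actually uses
print(torch.cuda.device_count(),
      [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])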

# https://github.com/pytorch/tutorials/blob/master/beginner_source/blitz/cifar10_tutorial.py
import torch
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm

DP = True
#DP = False

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4
num_workers = 4

cudnn.benchmark = True
torch.backends.cudnn.enabled = True

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=num_workers,
                                          pin_memory=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=num_workers)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

class ResNetWrapper(nn.Module):
    """Wraps resnet18 and computes the loss inside forward(), so that
    DataParallel gathers one loss value per GPU instead of the full logits."""
    def __init__(self):
        super().__init__()
        self.resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.loss = nn.CrossEntropyLoss()
    def forward(self, x, gt=None):
        output = self.resnet(x)

        if gt is not None:
            return self.loss(output, gt)
        return output


print('Initializing Model')
#net = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
net = ResNetWrapper()

if DP:
    net = torch.nn.DataParallel(net).cuda()
else:
    net = net.cuda()

#criterion = nn.CrossEntropyLoss()
#optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(net.parameters(), lr=0.001)

print("Start Training")
first_iteration = True
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    pbar = tqdm(total=len(trainloader))
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        if not DP:
            inputs = inputs.cuda()
            labels = labels.cuda()

        if first_iteration:
            print("passing data to device")

        # zero the parameter gradients
        optimizer.zero_grad()

        if first_iteration:
            print("forwarding input")

        # forward + backward + optimize

        loss = net(inputs, labels)
        loss = loss.sum()  # under DP this is a tensor with one loss per GPU; reduce it to a scalar

        if first_iteration:
            print("backwarding loss")
        loss.backward()

        if first_iteration:
            print("updating with optimizer")
            first_iteration = False
        optimizer.step()
        """
        # print statistics
        loss_val = loss.detach().cpu().item()
        running_loss += loss_val
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
        """
        pbar.update(1)
    pbar.close()

print('Finished Training')

########################################################################
# Let's quickly save our trained model:

PATH = './cifar_net.pth'
# unwrap the DataParallel container so the checkpoint keys are not prefixed with 'module.'
torch.save(net.module.state_dict() if DP else net.state_dict(), PATH)

net = ResNetWrapper()
net.load_state_dict(torch.load(PATH))
net.eval()

correct = 0
total = 0
# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = net(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        for label, prediction in zip(labels, predicted):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

Check if IOMMU is enabled and disable it to avoid hangs.
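
One way to check whether the IOMMU is currently active, assuming a standard Linux sysfs layout (the paths here are an assumption, not something tied to this particular server):

from pathlib import Path

# populated IOMMU groups usually mean the IOMMU is active
groups_dir = Path("/sys/kernel/iommu_groups")
print("IOMMU groups:", len(list(groups_dir.iterdir())) if groups_dir.exists() else 0)

# the kernel command line shows flags such as intel_iommu=on/off or iommu=pt
print(Path("/proc/cmdline").read_text().strip())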

Thanks. I disabled ‘Intel Virtualization Technology’ in the BIOS, and that solved the problem.

I ran into the same problem once again, and it turns out that disabling Intel VT is not a complete solution.
The PLX PCI bridges still have ACS enabled.
So you should either update to a BIOS version that fixes this issue, or find the PCI bridges yourself and disable ACS on each one.
I managed to disable ACS on all of them thanks to the instructions below:
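
As a rough, hypothetical illustration of the "find the bridges that still have ACS enabled" step (not the instructions referenced above), a small helper that parses lspci output; run it as root so the capability details are visible:

import subprocess

# list PCI devices whose ACS control still has Source Validation enabled
out = subprocess.run(["lspci", "-vvv"], capture_output=True, text=True).stdout
current = None
for line in out.splitlines():
    if line and not line[0].isspace():
        current = line.split()[0]            # e.g. "3a:00.0 PCI bridge: ..."
    elif "ACSCtl" in line and "SrcValid+" in line:
        print(f"ACS enabled on {current}: {line.strip()}")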