Greetings,
I have the following DDP training script for CIFAR-10:
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import StepLR
from torch.utils import data
from torchvision import models, datasets
import argparse
from tqdm import tqdm
precisions = {
    "fp16": torch.float16,
    "fp32": torch.float32,
    "bf16": torch.bfloat16,
}
def train(epoch, device, train_loader, model, loss_criterion, optimizer):
    model.train()
    scaler = GradScaler()
    train_loader = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        # autocast runs the forward pass in the dtype selected via --precision;
        # args is read from module scope (it is set in __main__)
        with autocast(dtype=precisions[args.precision]):
            outputs = model(inputs)
            loss = loss_criterion(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loader.set_postfix({'Loss': loss.item()}, refresh=False)
    train_loader.close()
def validate(epoch, device, validation_loader, model, loss_criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, targets in validation_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            test_loss += loss_criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
    test_loss /= len(validation_loader.dataset)
    accuracy = 100. * correct / len(validation_loader.dataset)
    print(f'\nValidation. Epoch: {epoch}, Loss: {test_loss:.4f}, Accuracy: ({accuracy:.2f}%)\n')
def cleanup():
    dist.destroy_process_group()
def train_dataloader(args, rank, world_size):
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    train_sampler = data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    train_loader = data.DataLoader(dataset, batch_size=args.batch_size, num_workers=1,
                                   persistent_workers=True, sampler=train_sampler)
    return train_loader
def val_dataloader(args):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
    val_loader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=1,
                                 persistent_workers=True)
    return val_loader
def main(args, rank, world_size):
    torch.manual_seed(0)
    train_loader = train_dataloader(args, rank, world_size)
    validation_loader = val_dataloader(args)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Rank {rank}: Using CUDA device {device}")
    model = models.resnet18(weights=None, num_classes=10).to(device)
    ddp_model = DistributedDataParallel(model)
    loss_criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.AdamW(ddp_model.parameters(), lr=0.01, weight_decay=5e-4)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.05)
    for epoch in range(args.epochs):
        train(epoch, device, train_loader, ddp_model, loss_criterion, optimizer)
        validate(epoch, device, validation_loader, ddp_model, loss_criterion)
        scheduler.step()
    cleanup()
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--precision", type=str, default="fp32")
    return parser.parse_args()
if __name__ == "__main__":
    args = get_args()
    try:
        torch.distributed.init_process_group(backend='gloo', init_method='env://')
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        main(args, rank, world_size)
    except Exception as e:
        print("An error occurred:", e)
        raise
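For context, I start the script with torchrun on a single node with the one GPU shown below; the exact invocation may differ slightly, but it is roughly this (the launcher provides the env:// rendezvous variables that init_process_group expects):

# assumed launch command: single node, one GPU, bf16 requested
torchrun --standalone --nproc_per_node=1 cifar10_pytorch.py --epochs 10 --batch-size 64 --precision bf16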
When I run it with --precision bf16, it raises this error:
An error occurred: Current CUDA Device does not support bfloat16. Please switch dtype to float16.
Traceback (most recent call last):
File "/home/3458/pytorch/cifar10_pytorch.py", line 125, in <module>
main(args, rank, world_size)
File "/home/3458/pytorch/cifar10_pytorch.py", line 105, in main
train(epoch, device, train_loader, ddp_model, loss_criterion, optimizer)
File "/home/3458/pytorch/cifar10_pytorch.py", line 29, in train
with autocast(dtype=precisions[args.precision]):
File "/home/3458/pytorch/venv/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 34, in __init__
super().__init__(
File "/home/3458/pytorch/venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 306, in __init__
raise RuntimeError(
RuntimeError: Current CUDA Device does not support bfloat16. Please switch dtype to float16.
However, when I run this snippet:
import torch
x = torch.randn(3, 3, dtype=torch.bfloat16)
print(x)
it outputs:
tensor([[ 0.8203, 2.1562, -1.8047],
[ 0.0879, 0.2354, -1.0781],
[-1.0469, -0.8984, -1.7656]], dtype=torch.bfloat16)
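Is something like the following the right way to confirm whether the GPU itself can use bf16? (Just a sketch; I am assuming torch.cuda.is_bf16_supported() and torch.cuda.get_device_capability() are the relevant checks. This is not part of my script.)

import torch

# assumed capability probe, not in my original script
print(torch.cuda.get_device_name(0))        # reports the GPU model, e.g. Tesla T4
print(torch.cuda.get_device_capability(0))  # compute capability as a (major, minor) tuple
print(torch.cuda.is_bf16_supported())       # the check autocast appears to rely on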
nvidia-smi output:
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 Tesla T4 Off | 00000000:81:00.0 Off | 0 |
| N/A 47C P0 28W / 70W | 0MiB / 15360MiB | 5% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
Can you advise why autocast reports that my device does not support bfloat16, even though creating a bfloat16 tensor works without any error?
Best