Extremely high CPU usage on x86/CUDA vs. almost none on M1/MPS with the same CNN test code

I run the test code below on two Ubuntu LTS systems (24/32 cores, A30/A6000 GPUs), and during the training loop the CPU usage is around 70% or more on ALL cores!

The same code with device = "mps" on an M1 uses a single core at around 30-50%.

What am I missing?! (FYI, I'm not expecting the model to be a good model; I'm worried about the performance of this code!)
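(For reference, a minimal sketch of how PyTorch's CPU thread pools could be inspected or capped; this is separate from the test code below, and I'm assuming default thread settings:)

import torch

# Assumption: nothing in the environment overrides the defaults (e.g. OMP_NUM_THREADS).
print("intra-op threads:", torch.get_num_threads())
print("inter-op threads:", torch.get_num_interop_threads())
torch.set_num_threads(1)  # cap intra-op CPU parallelism to a single thread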

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

#device = "mps"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# %%
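# 10,000 synthetic 200x200 float32 images (a channel dimension is added on the next line)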
X = np.random.rand(10000, 200, 200).astype(np.float32) 
X_tensor = torch.tensor(X[:, None, :, :])

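# Rough per-channel mean/std: average of per-image statistics over the whole loader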
def calculate_mean_std(loader):
    mean = 0.0
    variance = 0.0
    total_images = 0
    for images, _ in loader:
        batch_samples = images.size(0)
        images = images.view(batch_samples, images.size(1), -1)
        mean += images.mean(2).sum(0)
        variance += images.var(2).sum(0)
        total_images += batch_samples
    mean /= total_images
    std = torch.sqrt(variance / total_images)
    return mean, std

temp_loader = DataLoader(TensorDataset(X_tensor, torch.zeros(len(X_tensor))), batch_size=64, shuffle=False)

mean, std = calculate_mean_std(temp_loader)
print(f'Mean: {mean}, Standard Deviation: {std}')


# %%
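# Identical transforms for train and test; both run on the CPU, per sample, inside __getitem__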
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# %%
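# Pretrained SqueezeNet 1.1, adapted below for 1-channel input and a single regression output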
model = models.squeezenet1_1(pretrained=True)

model.features[0] = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=3)

model.classifier = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Conv2d(512, 1, kernel_size=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1))
)

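# Custom forward that returns a flat (batch,) tensor to match the 1-D float targets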
def forward(self, x):
    x = self.features(x)
    x = self.classifier(x)
    return x.view(x.size(0))

model.forward = forward.__get__(model)
model.to(device)

# %%
from torch.utils.data import Dataset, random_split

class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


y = np.random.rand(10000).astype(np.float32)

full_dataset = CustomDataset(X, y, transform=None) 

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

train_dataset.dataset.transform = train_transforms
test_dataset.dataset.transform = test_transforms
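# Note: both Subsets share the same underlying CustomDataset, so the line above also
# sets the transform for the train split (harmless here since both transforms are identical)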

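# Default DataLoader settings (num_workers=0): batches and transforms are prepared in the main process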
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)


# %%
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# %%
train_losses = []
val_losses = []

# %%
num_epochs = 10 

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * inputs.size(0)
    train_loss /= len(train_loader.dataset)
    train_losses.append(train_loss)

    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_inputs, val_targets in test_loader:
            val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
            val_outputs = model(val_inputs)
            val_loss += criterion(val_outputs, val_targets).item() * val_inputs.size(0)
    val_loss /= len(test_loader.dataset)
    val_losses.append(val_loss)

    print(f'Epoch {epoch + 1}, Training Loss: {train_loss}, Validation Loss: {val_loss}')
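
(In case it helps to narrow this down, a minimal sketch of how a few training batches could be profiled with torch.profiler to see where the CPU time goes; this assumes CUDA is available and is not part of the run described above:)

from torch.profiler import profile, ProfilerActivity

# Sketch only: profile a handful of training batches and print the top ops by CPU time.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for i, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        if i == 4:  # five batches are enough for a first look
            break
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))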