1000x difference between time measured inside forward() and around the model(inputs) call on MPS, even with synchronize

Hi there. I’ve been scratching my head for quite some time because I am observing a big difference between the time measured inside a model’s forward method and the time measured around the model(inputs) call (with mps.synchronize where needed). I see this difference only when I run the PyTorch code on MPS on my MacBook Pro. I don’t observe it when I use the CPU.

I would love to understand whether my expectation that the two times should be the same is really wrong.

This is my understanding - inside the model’s forward call, the entire data gets consumed. Low level operations might be parallelized, but there’s just one forward call so there’s no reason why time measured outside forward() should be significantly different from time measured inside forward().

Here’s a minimal example to reproduce:

import torch
from torch import mps
from timeit import default_timer as timer
from torch.utils.data import Dataset, random_split, DataLoader

max_train_steps = 30
log_interval = 10
size = (256, 256)
device = torch.device("mps")
# device = torch.device("cpu")
print('Using device:', device)

def timeit(device):
    # Wait for queued MPS work to finish before reading the clock.
    if device == torch.device("mps"):
        mps.synchronize()
    return timer()

# Dataset 
class FakeDataset(Dataset):
    def __init__(self):
        self.x = 1
    def __len__(self):
        return 1000
    def __getitem__(self, idx):
        m1 = torch.rand(8192, 2048)
        m2 = torch.rand(2048, 8192)
        return m1, m2

# Model
class FakeModel(torch.nn.Module):
    def __init__(self):
        super(FakeModel, self).__init__()
        self.x = torch.nn.Linear(1,1, bias=False)

    def forward(self, m1, m2):
        # Some naive model that just multiplies tensors and has a linear 1 -> 1 layer.
        device = m1.device
        t1 = timeit(device)
        val = torch.matmul(m1, m2)
        loss = self.x(val.sum().unsqueeze(0))
        logits = 15.8
        t2 = timeit(device)
        print(f'In model: total time = {t2-t1}')
        return logits, loss

dataset = FakeDataset()
train_ds, test_ds = random_split(dataset, [0.8, 0.2])
train_dataloader = DataLoader(train_ds, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=8, shuffle=True)
model = FakeModel().to(device)

for batch_idx, (data, target) in enumerate(train_dataloader):
    if batch_idx >= 20:
        break
    data, target = data.to(device), target.to(device)
    forward_start_time = timeit(device)
    _, loss = model(data, target)
    forward_end_time = timeit(device)
    print(f'In train loop: time around model forward = {forward_end_time - forward_start_time}')

Sample output seen is:

In model: total time = 0.0005140830180607736
In train loop: time around model forward = 0.5661588329821825

That is, the time measured around model(data, target) is about 1100 times the time measured inside forward(). They should be the same unless I’m missing something fundamental.
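
One thing that may be worth checking, sketched here under an assumption I have not confirmed on this machine: if tensors moved to the device report m1.device as mps:0 rather than mps, then the device == torch.device("mps") comparison in timeit would be False inside forward(), and the synchronize would be skipped there. The helper name sync_timer below is hypothetical. It compares on device.type instead of on the whole device object, so it should match with or without an index.

```python
import torch
from timeit import default_timer as timer

def sync_timer(device):
    """Return a timestamp, waiting for queued MPS work first if needed.

    Hypothetical helper: assumes `device` is a torch.device.
    """
    # Compare on device.type so that both "mps" and "mps:0" match.
    if device.type == "mps":
        torch.mps.synchronize()
    return timer()

# A device without an index and one with an index do not compare equal,
# even when they would refer to the same hardware.
print(torch.device("mps") == torch.device("mps:0"))  # False
print(torch.device("mps:0").type == "mps")           # True
```

Printing m1.device inside forward() would show whether this applies here. If it does, the time measured inside forward() would cover only the dispatch of the work, not the work itself.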