Hi there. I’ve been scratching my head for quite some time over a large gap between two timings of the same forward pass: one taken inside the model’s forward method and one taken around the model(inputs) call, both using mps.synchronize where needed. I only see the gap when I run the PyTorch code on MPS on my MacBook Pro; on CPU the two timings agree.
Am I wrong to expect the two times to be the same?
This is my understanding: the forward call consumes the entire batch. Low-level operations might be parallelized, but there’s only one forward call, so I see no reason why the time measured outside forward() should differ significantly from the time measured inside forward().
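The expectation above holds only if both measurements wait for the queued work to finish. Here is a toy sketch of that, using only the standard library and no PyTorch: a single-worker thread pool stands in for the device queue, and future.result() stands in for mps.synchronize(). The names fake_forward and timed_call are made up for illustration. This is an analogy under the assumption that MPS work is queued asynchronously; it is not a description of how MPS is implemented.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from timeit import default_timer as timer

# A single worker thread stands in for a device command queue (analogy only).
executor = ThreadPoolExecutor(max_workers=1)

def fake_forward(sync_inside):
    """Queue 0.2 s of 'work' and time the call from the inside."""
    t1 = timer()
    future = executor.submit(time.sleep, 0.2)  # returns immediately
    if sync_inside:
        future.result()  # wait for the queued work, like a synchronize call
    t2 = timer()
    return future, t2 - t1

def timed_call(sync_inside):
    """Time the same call from the outside, always waiting for completion."""
    start = timer()
    future, inner = fake_forward(sync_inside)
    future.result()
    outer = timer() - start
    return inner, outer

inner_nosync, outer_nosync = timed_call(sync_inside=False)
inner_sync, outer_sync = timed_call(sync_inside=True)

print(f"no sync inside: inner={inner_nosync:.4f}s outer={outer_nosync:.4f}s")
print(f"sync inside:    inner={inner_sync:.4f}s outer={outer_sync:.4f}s")

executor.shutdown()
```

When the inner measurement waits, the two numbers agree closely; when it does not, the inner number only covers the cost of queuing the work.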
Here’s a minimal example to reproduce:
import torch
from torch import mps
from timeit import default_timer as timer
from torch.utils.data import Dataset, random_split, DataLoader

max_train_steps = 30
log_interval = 10
size = (256, 256)
device = torch.device("mps")
# device = torch.device("cpu")
print('Using device:', device)

def timeit(device):
    if device == torch.device("mps"):
        mps.synchronize()
    return timer()

# Dataset
class FakeDataset(Dataset):
    def __init__(self):
        self.x = 1

    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        m1 = torch.rand(8192, 2048)
        m2 = torch.rand(2048, 8192)
        return m1, m2

# Model
class FakeModel(torch.nn.Module):
    def __init__(self):
        super(FakeModel, self).__init__()
        self.x = torch.nn.Linear(1, 1, bias=False)

    def forward(self, m1, m2):
        # Some naive model that just multiplies tensors and has a linear 1 -> 1 layer.
        device = m1.device
        t1 = timeit(device)
        val = torch.matmul(m1, m2)
        loss = self.x(val.sum().unsqueeze(0))
        logits = 15.8
        t2 = timeit(device)
        print("**********************")
        print(f'In model: total time = {t2-t1}')
        print("**********************")
        return logits, loss

dataset = FakeDataset()
train_ds, test_ds = random_split(dataset, [0.8, 0.2])
train_dataloader = DataLoader(train_ds, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=8, shuffle=True)
model = FakeModel().to(device)

for batch_idx, (data, target) in enumerate(train_dataloader):
    if batch_idx >= 20:
        break
    data, target = data.to(device), target.to(device)
    forward_start_time = timeit(device)
    _, loss = model(data, target)
    forward_end_time = timeit(device)
    print(f'In train loop: time around model forward = {forward_end_time - forward_start_time}')
    print('\n\n')
Sample output:
**********************
In model: total time = 0.0005140830180607736
**********************
In train loop: time around model forward = 0.5661588329821825
That is, the time measured around model(data, target) is roughly 1100 times the time measured inside forward(). The two should be the same unless I’m missing something fundamental.
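One thing that might be worth ruling out, as an assumption rather than a confirmed diagnosis: timeit() compares the device with ==, and torch.device equality also takes the device index into account. If m1.device inside forward() reports an indexed device such as mps:0 (printing it would confirm this), the comparison with torch.device("mps") would be False there and mps.synchronize() would be skipped, while the training loop passes the un-indexed global device and does synchronize. The sketch below only shows the comparison semantics and a hypothetical is_mps helper that checks device.type instead; it does not allocate anything on an MPS device.

```python
import torch

# Building device objects does not allocate anything on the device.
plain = torch.device("mps")      # what the training loop passes to timeit()
indexed = torch.device("mps:0")  # assumed form of m1.device inside forward()

# Equality compares both the device type and the index.
print("plain == indexed:", plain == indexed)
print("types:", plain.type, indexed.type)

def is_mps(device: torch.device) -> bool:
    """Hypothetical helper: match on the device type and ignore the index."""
    return device.type == "mps"

print("is_mps(plain):", is_mps(plain))
print("is_mps(indexed):", is_mps(indexed))
```

If the inner device does turn out to be indexed, using a type-based check like is_mps inside timeit() would make both measurements synchronize the same way.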