Given some m-by-m covariance matrix C with missing entries, I would like to find a low-rank m-by-k matrix L, where k << m, such that LL^T is a good approximation to C. To do so, I’ve coded a distributed version of stochastic gradient descent using PyTorch’s distributed framework.
I noticed that, for a fixed number of observed points in C, there is a dramatic increase in the time required to complete a backward pass once m exceeds some threshold. I profiled my code and found that, after m reaches this threshold, the average CPU time for calls to aten::fill_ and aten::add_ during the backward pass spikes and then plateaus. Below is a minimal working example demonstrating this behavior.
# mwe.py
import argparse
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.functional import mse_loss
from typing import Callable
class LowRankCovariance(nn.Module):
    """Low-rank factor model: entry (i, j) is predicted as <loads[i], loads[j]>.

    The single learnable parameter ``loads`` has shape ``(num_vars, num_facs)``
    and dtype ``float64``; it is initialised from a standard normal.

    Args:
        num_vars: Number of rows of the loadings matrix.
        num_facs: Number of columns (factors) of the loadings matrix.
    """

    def __init__(
        self,
        num_vars: int,
        num_facs: int
    ):
        super().__init__()
        self.num_facs = num_facs
        # nn.Parameter already sets requires_grad=True and registers the
        # tensor with the module, so no separate requires_grad_() is needed.
        self.loads = nn.Parameter(
            torch.randn(num_vars, num_facs, dtype=torch.float64)
        )

    def forward(self, idx0, idx1):
        """Return the predicted entries for the index pairs (idx0, idx1).

        Args:
            idx0: Row indices into ``loads`` (integer tensor or int).
            idx1: Row indices into ``loads``, broadcastable against ``idx0``.

        Returns:
            Tensor with the broadcast shape of the indices, holding the inner
            product of the two selected loading rows for each pair.
        """
        # Reduce over the trailing (factor) axis so scalar and
        # multi-dimensional indices work as well as 1-D batches.
        return (self.loads[idx0] * self.loads[idx1]).sum(dim=-1)
def init_process(
    rank: int,
    world_size: int,
    fcn: Callable,
    path_shared: str,
    num_vars: int,
    num_facs: int,
    batch_size: int,
    lr: float,
    seed: int
):
    """Join the gloo process group, then hand control to the worker function.

    Args:
        rank: Rank of this process within the group.
        world_size: Total number of processes in the group.
        fcn: Worker callable, invoked as
            ``fcn(rank, num_vars, num_facs, batch_size, lr, seed)``.
        path_shared: File path used for file-based rendezvous.
        num_vars: Forwarded to ``fcn``.
        num_facs: Forwarded to ``fcn``.
        batch_size: Forwarded to ``fcn``.
        lr: Forwarded to ``fcn``.
        seed: Forwarded to ``fcn``.
    """
    rendezvous_url = f'file://{path_shared}'
    dist.init_process_group(
        backend='gloo',
        init_method=rendezvous_url,
        world_size=world_size,
        rank=rank,
    )
    worker_args = (rank, num_vars, num_facs, batch_size, lr, seed)
    fcn(*worker_args)
def train(
    rank: int,
    num_vars: int,
    num_facs: int,
    batch_size: int,
    lr: float,
    seed: int
) -> None:
    """Run one profiled pass of mini-batch SGD over synthetic data.

    Generates 10000 random index pairs with random target values, fits a
    ``LowRankCovariance`` model to them batch by batch under the autograd
    profiler, and finally tears down the process group. Rank 0 also exports
    a Chrome trace next to this script and prints the profiler summary.

    Args:
        rank: Rank of this process; only rank 0 writes profiler output.
        num_vars: Number of variables; indices are drawn from [0, num_vars).
        num_facs: Number of factors of the low-rank model.
        batch_size: Number of samples per SGD step; must be positive.
        lr: SGD learning rate.
        seed: Seed for the generator that produces the synthetic data.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    gen = torch.Generator().manual_seed(seed)
    # Generate fake data
    sz = 10000
    points1 = torch.randint(0, num_vars, (sz,), dtype=torch.int32, generator=gen)
    points2 = torch.randint(0, num_vars, (sz,), dtype=torch.int32, generator=gen)
    points = torch.column_stack([points1, points2])
    cov = torch.randn(sz, dtype=torch.float64, generator=gen)

    model = LowRankCovariance(num_vars, num_facs)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    with torch.autograd.profiler.profile(enabled=True) as prof:
        for cnt in range(0, sz, batch_size):
            # Get batch of data. The upper bound is clamped to the dataset
            # size; the previous min(cnt, cnt + batch_size) always evaluated
            # to cnt and therefore produced empty batches.
            last_idx = min(cnt + batch_size, sz)
            points_batch = points[cnt:last_idx]
            cov_batch = cov[cnt:last_idx]

            # Forward pass
            preds = model(points_batch[:, 0], points_batch[:, 1])
            loss = mse_loss(preds, cov_batch)

            # Backward pass
            optimizer.zero_grad()
            with torch.autograd.profiler.record_function('backward_pass'):
                loss.backward()
            optimizer.step()

    # Profiler results are read only after the profiling context has exited.
    if rank == 0:
        script_dir = os.path.dirname(os.path.realpath(__file__))
        prof.export_chrome_trace(f'{script_dir}/profile_{num_vars}.json')
        print(prof.key_averages().table(sort_by="self_cpu_time_total"))

    dist.destroy_process_group()
if __name__ == '__main__':
    # Parse the command line. Every option without a default is required:
    # a missing value would otherwise arrive as None and fail obscurely in
    # range() or inside the worker processes.
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_vars', type=int, required=True)
    parser.add_argument('--world_size', type=int, required=True)
    parser.add_argument('--num_facs', type=int, required=True)
    parser.add_argument('--batch_size', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--seed', type=int, default=12345)
    args = parser.parse_args()

    # ---------- DISTRIBUTED TRAINING ---------- #
    # Refresh file for shared file-system initialization. A rendezvous file
    # left over from a previous run must be removed before starting.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    path_shared = os.path.join(script_dir, 'shared')
    try:
        os.remove(path_shared)
    except FileNotFoundError:
        pass

    mp.set_start_method('spawn')
    processes = []
    for rank in range(args.world_size):
        # Each worker gets its own seed so the synthetic data differs by rank.
        p = mp.Process(
            target=init_process,
            args=(
                rank, args.world_size, train, path_shared,
                args.num_vars, args.num_facs, args.batch_size,
                args.lr, args.seed + rank
            )
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # Propagate worker failures instead of exiting with status 0.
    if any(p.exitcode != 0 for p in processes):
        raise SystemExit(1)
If you run the script five times (see commands below) with m = 8190, 8191, 8192, 8193, and 8194, you’ll notice that the average CPU times for aten::fill_ and aten::add_ start low, increase slightly at m = 8192, then increase dramatically at m = 8193 and remain high.
$ python mwe.py --num_vars 8190 --world_size 5 --num_facs 4 --batch_size 512 --lr 0.1
$ python mwe.py --num_vars 8191 --world_size 5 --num_facs 4 --batch_size 512 --lr 0.1
$ python mwe.py --num_vars 8192 --world_size 5 --num_facs 4 --batch_size 512 --lr 0.1
$ python mwe.py --num_vars 8193 --world_size 5 --num_facs 4 --batch_size 512 --lr 0.1
$ python mwe.py --num_vars 8194 --world_size 5 --num_facs 4 --batch_size 512 --lr 0.1
Below are two screenshots of profile traces for m = 8191 and m = 8193.
Notice that in the m = 8193 trace, there are extended calls to aten::fill_ and aten::add_ that are not present in the m = 8191 trace.
Environment details:
- OS: Red Hat Enterprise Linux 7
- Python version: 3.10.12
- Pytorch version: 2.1.2
Other comments:
- I was also able to replicate this slowdown locally (macOS Ventura 13.2) using mwe.py, although it was not as pronounced.
- The slowdown worsens as the world size increases.
Does anyone know what causes this slowdown and how I can avoid it?