In case it interests the developers, here is a notable example (finite difference wave propagation) where PyTorch eager mode, torch.compile, and torch.jit.script are all substantially (5x+) slower than equivalent C++ and CUDA implementations.
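For reference, the update that all of the implementations below perform (just summarising the code, second-order accurate in time and fourth-order in space) is

$$u^{n+1} = 2u^{n} - u^{n-1} + v^2\,\Delta t^2\,\big(\partial_x^2 u^{n} + \partial_y^2 u^{n}\big),\qquad \partial_x^2 u \approx \frac{1}{\Delta x^2}\Big(-\tfrac{5}{2}u_{i} + \tfrac{4}{3}(u_{i+1}+u_{i-1}) - \tfrac{1}{12}(u_{i+2}+u_{i-2})\Big)$$

on a 200 x 200 grid, with 10 shots, 1000 time steps, a Ricker-wavelet point source injected at one cell per shot, and a single receiver per shot.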
PyTorch:
import math
import torch
from torch import Tensor
from timeit import timeit
def ricker(freq: float, length: int, dt: float, peak_time: float) -> Tensor:
"""Ricker wavelet for source."""
t = torch.arange(length) * dt - peak_time
a = math.pi**2 * freq**2 * t**2
return (1 - 2 * a) * torch.exp(-a)
def step(wfc: Tensor, wfp: Tensor, v2dt2: Tensor,
one_over_dx2: float, one_over_dy2: float):
"""One finite difference time step of wave equation.
Args:
wfc: current time step wavefield, [shot, x, y]
wfp: previous time step wavefield
v2dt2: velocity^2 * dt^2
        one_over_dx2, one_over_dy2: 1/dx^2 and 1/dy^2
"""
d2wdx2 = (
-5 / 2 * wfc[:, 2:-2, 2:-2]
+ 4 / 3 * (wfc[:, 3:-1, 2:-2] + wfc[:, 1:-3, 2:-2])
+ -1 / 12 * (wfc[:, 4:, 2:-2] + wfc[:, :-4, 2:-2])
) * one_over_dx2
d2wdy2 = (
-5 / 2 * wfc[:, 2:-2, 2:-2]
+ 4 / 3 * (wfc[:, 2:-2, 3:-1] + wfc[:, 2:-2, 1:-3])
+ -1 / 12 * (wfc[:, 2:-2, 4:] + wfc[:, 2:-2, :-4])
) * one_over_dy2
# wavefield at next time step
return torch.nn.functional.pad(
v2dt2 * (d2wdx2 + d2wdy2)
+ 2 * wfc[:, 2:-2, 2:-2]
- wfp[:, 2:-2, 2:-2], (2, 2, 2, 2)
)
def closure(step_fn, device):
def run(n_shots: int = 10, ny: int = 200, nx: int = 200,
nt: int = 1000, dt: float = 0.0005,
dy: float = 5, dx: float = 5, freq: float = 25):
# Setup input velocity, source, receiver
v = 1500*torch.ones(ny, nx, device=device)
v2dt2 = v[2:-2, 2:-2]**2 * dt**2
wfc = torch.zeros(n_shots, ny, nx, device=device)
wfp = torch.zeros(n_shots, ny, nx, device=device)
one_over_dx2 = 1 / dx**2
one_over_dy2 = 1 / dy**2
source_amplitudes = (
ricker(freq, nt, dt, 1.5/freq)
.to(device)
.reshape(-1, 1)
.repeat(1, n_shots)
)
sources_batch = torch.arange(n_shots, dtype=torch.long,
device=device)
sources_y = 2 * torch.ones(n_shots, dtype=torch.long,
device=device)
sources_x = 2 * torch.ones(n_shots, dtype=torch.long,
device=device)
receivers_batch = torch.arange(n_shots, dtype=torch.long,
device=device)
receivers_y = 2 * torch.ones(n_shots, dtype=torch.long, device=device)
receivers_x = 2 * torch.ones(n_shots, dtype=torch.long, device=device)
receiver_amplitudes = torch.zeros(nt, n_shots, device=device)
# Loop over time steps
for t in range(nt-1):
wfn = step_fn(wfc, wfp, v2dt2, one_over_dx2, one_over_dy2)
wfn[sources_batch,
sources_y,
sources_x] += source_amplitudes[t]
receiver_amplitudes[t] = \
                wfc[receivers_batch, receivers_y, receivers_x]
wfp, wfc = wfc, wfn
# Do something with output so compiler doesn't optimise away
return receiver_amplitudes.max()
return run
def timing(name, run_fn, n_warmup=3):
ans = 0
for _ in range(n_warmup):
ans = run_fn()
print(name, timeit(lambda: run_fn(), number=10), ans)
def eager():
with torch.no_grad():
timing('eager cpu', closure(step, torch.device('cpu')))
timing('eager gpu', closure(step, torch.device('cuda:0')))
def torchcompile():
with torch.no_grad():
timing('compile cpu', closure(torch.compile(step, fullgraph=True),
torch.device('cpu')))
timing('compile gpu', closure(torch.compile(step, fullgraph=True),
torch.device('cuda:0')))
def jitscript():
with torch.no_grad():
timing('jitscript cpu', closure(torch.jit.script(step),
torch.device('cpu')))
timing('jitscript gpu', closure(torch.jit.script(step),
torch.device('cuda:0')))
eager()
torchcompile()
jitscript()
C++:
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <algorithm>
namespace {
void step(float const *__restrict const wfc, float *__restrict const wfp,
float const *__restrict const v2dt2, float const one_over_dx2,
float const one_over_dy2, int const n_shots, int const ny,
int const nx) {
for (int shot = 0; shot < n_shots; ++shot) {
for (int y = 2; y < ny - 2; ++y) {
for (int x = 2; x < nx - 2; ++x) {
float d2wdx2 = (-5.0f / 2.0f * wfc[shot * ny * nx + y * nx + x] +
4.0f / 3.0f *
(wfc[shot * ny * nx + (y + 1) * nx + x] +
wfc[shot * ny * nx + (y - 1) * nx + x]) +
-1.0f / 12.0f *
(wfc[shot * ny * nx + (y + 2) * nx + x] +
wfc[shot * ny * nx + (y - 2) * nx + x])) *
one_over_dx2;
float d2wdy2 = (-5.0f / 2.0f * wfc[shot * ny * nx + y * nx + x] +
4.0f / 3.0f *
(wfc[shot * ny * nx + y * nx + x + 1] +
wfc[shot * ny * nx + y * nx + x - 1]) +
-1.0f / 12.0f *
(wfc[shot * ny * nx + y * nx + x + 2] +
wfc[shot * ny * nx + y * nx + x - 2])) *
one_over_dy2;
wfp[shot * ny * nx + y * nx + x] =
(v2dt2[y * nx + x] * (d2wdx2 + d2wdy2) +
2 * wfc[shot * ny * nx + y * nx + x] -
wfp[shot * ny * nx + y * nx + x]);
}
}
}
}
void add_sources(float *__restrict wf, float const *__restrict f,
int const *__restrict sources_i, int n_shots) {
for (int shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
wf[sources_i[shot_idx]] += f[shot_idx];
}
}
void record_receivers(float *__restrict r, float const *__restrict wf,
int const *__restrict receivers_i, int n_shots) {
for (int shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
r[shot_idx] = wf[receivers_i[shot_idx]];
}
}
float run() {
int n_shots = 10;
int ny = 200;
int nx = 200;
int nt = 1000;
float dt = 0.0005f;
float dy = 5.0f;
float dx = 5.0f;
float freq = 25.0f;
float peak_time = 1.5f / freq;
float *v2dt2 = (float *)malloc(ny * nx * sizeof(float));
float *wfc = (float *)calloc(n_shots * ny * nx, sizeof(float));
float *wfp = (float *)calloc(n_shots * ny * nx, sizeof(float));
float *source_amplitudes = (float *)malloc(n_shots * nt * sizeof(float));
float *receiver_amplitudes = (float *)malloc(n_shots * nt * sizeof(float));
int *source_locations = (int *)malloc(n_shots * sizeof(int));
int *receiver_locations = (int *)malloc(n_shots * sizeof(int));
float one_over_dx2 = 1.0f / (dx * dx);
  float one_over_dy2 = 1.0f / (dy * dy);
for (int y = 0; y < ny; ++y) {
    for (int x = 0; x < nx; ++x) {
v2dt2[y * nx + x] = 1500 * 1500 * dt * dt;
}
}
for (int i = 0; i < nt; ++i) {
float t = i * dt - peak_time;
float a = M_PI * freq * t;
a = a * a;
for (int shot = 0; shot < n_shots; ++shot) {
source_amplitudes[i * n_shots + shot] = (1 - 2 * a) * expf(-a);
}
}
for (int shot = 0; shot < n_shots; ++shot) {
source_locations[shot] = shot * ny * nx + 2 * nx + 2;
receiver_locations[shot] = shot * ny * nx + 2 * nx + 2;
}
for (int t = 0; t < nt; ++t) {
step(wfc, wfp, v2dt2, one_over_dx2, one_over_dy2, n_shots, ny, nx);
add_sources(wfp, source_amplitudes + t * n_shots, source_locations,
n_shots);
record_receivers(receiver_amplitudes + t * n_shots, wfc, receiver_locations,
n_shots);
std::swap(wfc, wfp);
}
float ans = *std::max_element(receiver_amplitudes,
receiver_amplitudes + n_shots * nt);
free(v2dt2);
free(wfc);
free(wfp);
free(source_amplitudes);
free(receiver_amplitudes);
free(source_locations);
free(receiver_locations);
return ans;
}
} // namespace
int main(void) {
float ans = 0.0f;
// warmup
for (int i = 0; i < 3; ++i) {
ans = run();
}
// timing
clock_t t0 = clock();
for (int i = 0; i < 10; ++i) {
run();
}
clock_t t1 = clock();
printf("C++: %f %g\n", (float)(t1 - t0) / CLOCKS_PER_SEC, ans);
}
CUDA:
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <algorithm>
namespace {
__global__ void step(float const *__restrict const wfc,
float *__restrict const wfp,
float const *__restrict const v2dt2,
float const one_over_dx2, float const one_over_dy2,
int const n_shots, int const ny, int const nx) {
auto x{blockIdx.x * blockDim.x + threadIdx.x + 2};
auto y{blockIdx.y * blockDim.y + threadIdx.y + 2};
auto shot{blockIdx.z * blockDim.z + threadIdx.z};
if (y < ny - 2 and x < nx - 2) {
float d2wdx2 = (-5.0f / 2.0f * wfc[shot * ny * nx + y * nx + x] +
4.0f / 3.0f *
(wfc[shot * ny * nx + (y + 1) * nx + x] +
wfc[shot * ny * nx + (y - 1) * nx + x]) +
-1.0f / 12.0f *
(wfc[shot * ny * nx + (y + 2) * nx + x] +
wfc[shot * ny * nx + (y - 2) * nx + x])) *
one_over_dx2;
float d2wdy2 = (-5.0f / 2.0f * wfc[shot * ny * nx + y * nx + x] +
4.0f / 3.0f *
(wfc[shot * ny * nx + y * nx + x + 1] +
wfc[shot * ny * nx + y * nx + x - 1]) +
-1.0f / 12.0f *
(wfc[shot * ny * nx + y * nx + x + 2] +
wfc[shot * ny * nx + y * nx + x - 2])) *
one_over_dy2;
wfp[shot * ny * nx + y * nx + x] = (v2dt2[y * nx + x] * (d2wdx2 + d2wdy2) +
2 * wfc[shot * ny * nx + y * nx + x] -
wfp[shot * ny * nx + y * nx + x]);
}
}
__global__ void add_sources(float *__restrict wf, float const *__restrict f,
int const *__restrict sources_i, int n_shots) {
auto shot_idx{blockIdx.x * blockDim.x + threadIdx.x};
if (shot_idx < n_shots) {
wf[sources_i[shot_idx]] += f[shot_idx];
}
}
__global__ void record_receivers(float *__restrict r,
float const *__restrict wf,
int const *__restrict receivers_i,
int n_shots) {
auto shot_idx{blockIdx.x * blockDim.x + threadIdx.x};
if (shot_idx < n_shots) {
r[shot_idx] = wf[receivers_i[shot_idx]];
}
}
float run() {
int n_shots = 10;
int ny = 200;
int nx = 200;
int nt = 1000;
float dt = 0.0005f;
float dy = 5.0f;
float dx = 5.0f;
float freq = 25.0f;
float peak_time = 1.5f / freq;
float *v2dt2 = (float *)malloc(ny * nx * sizeof(float));
float *v2dt2_d;
float *wfc_d;
float *wfp_d;
float *source_amplitudes = (float *)malloc(n_shots * nt * sizeof(float));
float *source_amplitudes_d;
float *receiver_amplitudes = (float *)malloc(n_shots * nt * sizeof(float));
float *receiver_amplitudes_d;
int *source_locations = (int *)malloc(n_shots * sizeof(int));
int *source_locations_d;
int *receiver_locations = (int *)malloc(n_shots * sizeof(int));
int *receiver_locations_d;
float one_over_dx2 = 1.0f / (dx * dx);
  float one_over_dy2 = 1.0f / (dy * dy);
dim3 dimBlock(32, 32, 1);
dim3 dimGrid((nx - 4 + 31) / 32, (ny - 4 + 31) / 32, n_shots);
dim3 dimBlock_srcrcv(32, 1, 1);
dim3 dimGrid_srcrcv((n_shots + 31) / 32, 1, 1);
for (int y = 0; y < ny; ++y) {
    for (int x = 0; x < nx; ++x) {
v2dt2[y * nx + x] = 1500 * 1500 * dt * dt;
}
}
for (int i = 0; i < nt; ++i) {
float t = i * dt - peak_time;
float a = M_PI * freq * t;
a = a * a;
for (int shot = 0; shot < n_shots; ++shot) {
source_amplitudes[i * n_shots + shot] = (1 - 2 * a) * expf(-a);
}
}
for (int shot = 0; shot < n_shots; ++shot) {
source_locations[shot] = shot * ny * nx + 2 * nx + 2;
receiver_locations[shot] = shot * ny * nx + 2 * nx + 2;
}
cudaMalloc(&v2dt2_d, ny * nx * sizeof(float));
cudaMalloc(&wfc_d, n_shots * ny * nx * sizeof(float));
cudaMalloc(&wfp_d, n_shots * ny * nx * sizeof(float));
cudaMalloc(&source_amplitudes_d, n_shots * nt * sizeof(float));
cudaMalloc(&receiver_amplitudes_d, n_shots * nt * sizeof(float));
cudaMalloc(&source_locations_d, n_shots * sizeof(int));
cudaMalloc(&receiver_locations_d, n_shots * sizeof(int));
cudaMemcpy(v2dt2_d, v2dt2, ny * nx * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(source_amplitudes_d, source_amplitudes,
n_shots * nt * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(source_locations_d, source_locations, n_shots * sizeof(int),
             cudaMemcpyHostToDevice);
  cudaMemcpy(receiver_locations_d, receiver_locations, n_shots * sizeof(int),
             cudaMemcpyHostToDevice);
cudaMemset(wfc_d, 0, n_shots * ny * nx * sizeof(float));
cudaMemset(wfp_d, 0, n_shots * ny * nx * sizeof(float));
for (int t = 0; t < nt; ++t) {
step<<<dimGrid, dimBlock>>>(wfc_d, wfp_d, v2dt2_d, one_over_dx2,
one_over_dy2, n_shots, ny, nx);
add_sources<<<dimGrid_srcrcv, dimBlock_srcrcv>>>(
wfp_d, source_amplitudes_d + t * n_shots, source_locations_d, n_shots);
record_receivers<<<dimGrid_srcrcv, dimBlock_srcrcv>>>(
receiver_amplitudes_d + t * n_shots, wfc_d, receiver_locations_d,
n_shots);
std::swap(wfp_d, wfc_d);
}
cudaMemcpy(receiver_amplitudes, receiver_amplitudes_d,
n_shots * nt * sizeof(float), cudaMemcpyDeviceToHost);
float ans = *std::max_element(receiver_amplitudes,
receiver_amplitudes + n_shots * nt);
cudaFree(v2dt2_d);
cudaFree(wfc_d);
cudaFree(wfp_d);
cudaFree(source_amplitudes_d);
cudaFree(receiver_amplitudes_d);
cudaFree(source_locations_d);
cudaFree(receiver_locations_d);
free(v2dt2);
free(source_amplitudes);
free(receiver_amplitudes);
free(source_locations);
free(receiver_locations);
return ans;
}
} // namespace
int main(void) {
float ans = 0.0f;
// warmup
for (int i = 0; i < 3; ++i) {
ans = run();
}
// timing
clock_t t0 = clock();
for (int i = 0; i < 10; ++i) {
run();
}
clock_t t1 = clock();
printf("CUDA: %f %g\n", (float)(t1 - t0) / CLOCKS_PER_SEC, ans);
}
To run:
!g++ -Ofast -march=native wave.cpp
!./a.out
!nvcc wave.cu
!./a.out
!python wave.py
Output:
C++: 2.859534 13.1366
CUDA: 0.272206 13.1366
eager cpu 70.94255131099999 tensor(13.1366)
eager gpu 4.180636458000009 tensor(13.1366, device='cuda:0')
compile cpu 19.22466220399997 tensor(13.1366)
compile gpu 1.8795379320000052 tensor(13.1366, device='cuda:0')
jitscript cpu 23.204399749000004 tensor(13.1366)
jitscript gpu 1.8472471919999975 tensor(13.1366, device='cuda:0')
The fastest PyTorch time on the CPU is 19s (torch.compile), compared to 2.9s for C++.
The fastest PyTorch time on the GPU is 1.8s (torch.compile and torch.jit.script are roughly tied), compared to 0.3s for CUDA.
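One note on the GPU measurements: the timing() helper above does not call torch.cuda.synchronize(), so kernels still queued at the end of a run could in principle fall outside the timed region. A minimal variant that synchronizes is sketched below (hypothetical helper name timing_sync, reusing the closure() and timeit setup from the script above); it should not change the ranking, but it makes the GPU numbers unambiguous.

import torch
from timeit import timeit

def timing_sync(name, run_fn, n_warmup=3):
    """Variant of timing() that synchronizes the GPU so all queued
    kernels are included in the measured interval."""
    ans = 0
    for _ in range(n_warmup):
        ans = run_fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    def timed():
        out = run_fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return out

    print(name, timeit(timed, number=10), ans)

# Example, same setup as above:
# timing_sync('compile gpu (sync)',
#             closure(torch.compile(step, fullgraph=True), torch.device('cuda:0')))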