# Example where PyTorch substantially slower than C++/CUDA

In case it interests the developers, here is a notable example (finite difference wave propagation) where eager PyTorch, `torch.compile`, and `torch.jit.script` are all substantially (5X+) slower than C++ and CUDA.

PyTorch:

``````import math
import torch
from torch import Tensor
from timeit import timeit

def ricker(freq: float, length: int, dt: float, peak_time: float) -> Tensor:
    """Ricker wavelet for source.

    Args:
        freq: frequency parameter of the wavelet
        length: number of time samples to generate
        dt: time sample interval
        peak_time: time at which the wavelet reaches its maximum value of 1

    Returns:
        1D Tensor with `length` samples of (1 - 2a) * exp(-a), where
        a = (pi * freq * (i * dt - peak_time))**2 for sample index i.
    """
    t = torch.arange(length) * dt - peak_time
    a = math.pi**2 * freq**2 * t**2
    return (1 - 2 * a) * torch.exp(-a)

def step(wfc: Tensor, wfp: Tensor, v2dt2: Tensor,
         one_over_dx2: float, one_over_dy2: float) -> Tensor:
    """One finite difference time step of wave equation.

    Args:
        wfc: current time step wavefield, [shot, x, y]
        wfp: previous time step wavefield
        v2dt2: velocity^2 * dt^2, defined on the interior of the grid
            (i.e. broadcastable against wfc[:, 2:-2, 2:-2])
        one_over_dx2/dy2: 1/dx^2

    Returns:
        Wavefield at the next time step, with the same shape as wfc.
        Only the interior is updated; the two-cell border is zero.
    """
    # Second derivatives use a five-point stencil (coefficients
    # -5/2, 4/3, -1/12), which is why two cells are trimmed from each edge.
    d2wdx2 = (
        -5 / 2 * wfc[:, 2:-2, 2:-2]
        + 4 / 3 * (wfc[:, 3:-1, 2:-2] + wfc[:, 1:-3, 2:-2])
        + -1 / 12 * (wfc[:, 4:, 2:-2] + wfc[:, :-4, 2:-2])
    ) * one_over_dx2
    d2wdy2 = (
        -5 / 2 * wfc[:, 2:-2, 2:-2]
        + 4 / 3 * (wfc[:, 2:-2, 3:-1] + wfc[:, 2:-2, 1:-3])
        + -1 / 12 * (wfc[:, 2:-2, 4:] + wfc[:, 2:-2, :-4])
    ) * one_over_dy2

    # wavefield at next time step, zero-padded back to the full grid size
    return torch.nn.functional.pad(
        v2dt2 * (d2wdx2 + d2wdy2)
        + 2 * wfc[:, 2:-2, 2:-2]
        - wfp[:, 2:-2, 2:-2], (2, 2, 2, 2)
    )

def closure(step_fn, device):
    """Bind a time-stepping function and a device into a runnable simulation.

    Args:
        step_fn: callable with the signature of `step`; it must return the
            wavefield at the next time step
        device: torch device on which all tensors are created

    Returns:
        `run`, a function that sets up the model, propagates for `nt - 1`
        time steps, and returns a scalar Tensor derived from the recorded
        receiver data.
    """
    def run(n_shots: int = 10, ny: int = 200, nx: int = 200,
            nt: int = 1000, dt: float = 0.0005,
            dy: float = 5, dx: float = 5, freq: float = 25):

        # Setup input velocity, source, receiver
        v = 1500*torch.ones(ny, nx, device=device)
        v2dt2 = v[2:-2, 2:-2]**2 * dt**2
        wfc = torch.zeros(n_shots, ny, nx, device=device)
        wfp = torch.zeros(n_shots, ny, nx, device=device)
        one_over_dx2 = 1 / dx**2
        one_over_dy2 = 1 / dy**2
        source_amplitudes = (
            ricker(freq, nt, dt, 1.5/freq)
            .to(device)
            .reshape(-1, 1)
            .repeat(1, n_shots)
        )
        # One recorded sample per time step and shot.
        receiver_amplitudes = torch.zeros(nt, n_shots, device=device)
        sources_batch = torch.arange(n_shots, dtype=torch.long,
                                     device=device)
        sources_y = 2 * torch.ones(n_shots, dtype=torch.long,
                                   device=device)
        sources_x = 2 * torch.ones(n_shots, dtype=torch.long,
                                   device=device)
        receivers_batch = torch.arange(n_shots, dtype=torch.long,
                                       device=device)
        receivers_y = 2 * torch.ones(n_shots, dtype=torch.long, device=device)
        receivers_x = 2 * torch.ones(n_shots, dtype=torch.long, device=device)

        # Loop over time steps
        for t in range(nt-1):
            wfn = step_fn(wfc, wfp, v2dt2, one_over_dx2, one_over_dy2)
            wfn[sources_batch,
                sources_y,
                sources_x] += source_amplitudes[t]
            # NOTE(review): the receiver-recording statement was missing from
            # this listing; recording the freshly updated wavefield at index t
            # is a reconstruction - confirm against the original script.
            receiver_amplitudes[t] = wfn[receivers_batch,
                                         receivers_y,
                                         receivers_x]
            wfp, wfc = wfc, wfn

        # Do something with output so compiler doesn't optimise away
        # NOTE(review): the original reduction was lost from this listing;
        # a sum is used as a stand-in scalar - confirm.
        return receiver_amplitudes.sum()
    return run

def timing(name, run_fn, n_warmup=3):
    """Time ten calls of `run_fn` and print the label, seconds, and result.

    Args:
        name: label printed in front of the measurement
        run_fn: zero-argument callable to benchmark
        n_warmup: number of untimed calls made first; the value returned by
            the last warmup call is printed alongside the timing (0 if there
            are no warmup calls)
    """
    def sync():
        # CUDA kernels are launched asynchronously, so without a
        # synchronisation the timer may stop while work is still queued.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def timed_runs():
        for _ in range(10):
            run_fn()
        sync()

    ans = 0
    for _ in range(n_warmup):
        ans = run_fn()
    # Make sure warmup work does not leak into the timed region.
    sync()
    print(name, timeit(timed_runs, number=1), ans)

def eager():
timing('eager cpu', closure(step, torch.device('cpu')))
timing('eager gpu', closure(step, torch.device('cuda:0')))

def torchcompile():
timing('compile cpu', closure(torch.compile(step, fullgraph=True),
torch.device('cpu')))
timing('compile gpu', closure(torch.compile(step, fullgraph=True),
torch.device('cuda:0')))

def jitscript():
timing('jitscript cpu', closure(torch.jit.script(step),
torch.device('cpu')))
timing('jitscript gpu', closure(torch.jit.script(step),
torch.device('cuda:0')))

eager()
torchcompile()
jitscript()
``````

C++:

``````#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <algorithm>

namespace {

// Advance every shot by one time step.
//
// Reads the current wavefield `wfc` and overwrites `wfp` (which holds the
// previous wavefield on entry) with the next wavefield. All arrays are
// row-major: wavefields are [n_shots][ny][nx] and `v2dt2` is [ny][nx].
// Only cells at least two away from every edge are updated, because the
// stencil reaches two cells in each direction.
void step(float const *__restrict const wfc, float *__restrict const wfp,
          float const *__restrict const v2dt2, float const one_over_dx2,
          float const one_over_dy2, int const n_shots, int const ny,
          int const nx) {
  for (int shot = 0; shot < n_shots; ++shot) {
    int const shot_base = shot * ny * nx;
    for (int y = 2; y < ny - 2; ++y) {
      int const row_base = shot_base + y * nx;
      for (int x = 2; x < nx - 2; ++x) {
        int const i = row_base + x;

        // Second difference across rows (neighbours are nx elements apart).
        float const d2wdx2 =
            (-5.0f / 2.0f * wfc[i] +
             4.0f / 3.0f * (wfc[i + nx] + wfc[i - nx]) +
             -1.0f / 12.0f * (wfc[i + 2 * nx] + wfc[i - 2 * nx])) *
            one_over_dx2;

        // Second difference along a row (neighbours are adjacent in memory).
        float const d2wdy2 =
            (-5.0f / 2.0f * wfc[i] +
             4.0f / 3.0f * (wfc[i + 1] + wfc[i - 1]) +
             -1.0f / 12.0f * (wfc[i + 2] + wfc[i - 2])) *
            one_over_dy2;

        wfp[i] = (v2dt2[y * nx + x] * (d2wdx2 + d2wdy2) + 2 * wfc[i] - wfp[i]);
      }
    }
  }
}

// Inject one source sample per shot into the wavefield.
//
// For each k in [0, n_shots), adds f[k] to the wavefield element whose flat
// index is sources_i[k].
void add_sources(float *__restrict wf, float const *__restrict f,
                 int const *__restrict sources_i, int n_shots) {
  int k = 0;
  while (k < n_shots) {
    int const cell = sources_i[k];
    float const amplitude = f[k];
    wf[cell] += amplitude;
    ++k;
  }
}

// Sample the wavefield at each shot's receiver.
//
// For each k in [0, n_shots), copies the wavefield element at flat index
// receivers_i[k] into r[k]. This mirrors add_sources, which writes to the
// wavefield at the source indices.
void record_receivers(float *__restrict r, float const *__restrict wf,
                      int const *__restrict receivers_i, int n_shots) {
  for (int shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
    r[shot_idx] = wf[receivers_i[shot_idx]];
  }
}

float run() {
int n_shots = 10;
int ny = 200;
int nx = 200;
int nt = 1000;
float dt = 0.0005f;
float dy = 5.0f;
float dx = 5.0f;
float freq = 25.0f;
float peak_time = 1.5f / freq;
float *v2dt2 = (float *)malloc(ny * nx * sizeof(float));
float *wfc = (float *)calloc(n_shots * ny * nx, sizeof(float));
float *wfp = (float *)calloc(n_shots * ny * nx, sizeof(float));
float *source_amplitudes = (float *)malloc(n_shots * nt * sizeof(float));
float *receiver_amplitudes = (float *)malloc(n_shots * nt * sizeof(float));
int *source_locations = (int *)malloc(n_shots * sizeof(int));
int *receiver_locations = (int *)malloc(n_shots * sizeof(int));
float one_over_dx2 = 1.0f / (dx * dx);
float one_over_dy2 = 1.0f / (dy * dx);

for (int y = 0; y < ny; ++y) {
for (int x = 0; x < ny; ++x) {
v2dt2[y * nx + x] = 1500 * 1500 * dt * dt;
}
}
for (int i = 0; i < nt; ++i) {
float t = i * dt - peak_time;
float a = M_PI * freq * t;
a = a * a;
for (int shot = 0; shot < n_shots; ++shot) {
source_amplitudes[i * n_shots + shot] = (1 - 2 * a) * expf(-a);
}
}
for (int shot = 0; shot < n_shots; ++shot) {
source_locations[shot] = shot * ny * nx + 2 * nx + 2;
receiver_locations[shot] = shot * ny * nx + 2 * nx + 2;
}

for (int t = 0; t < nt; ++t) {
step(wfc, wfp, v2dt2, one_over_dx2, one_over_dy2, n_shots, ny, nx);
add_sources(wfp, source_amplitudes + t * n_shots, source_locations,
n_shots);
n_shots);
std::swap(wfc, wfp);
}

free(v2dt2);
free(wfc);
free(wfp);
free(source_amplitudes);
free(source_locations);

return ans;
}

}  // namespace

int main(void) {
float ans = 0.0f;

// warmup
for (int i = 0; i < 3; ++i) {
ans = run();
}

// timing
clock_t t0 = clock();
for (int i = 0; i < 10; ++i) {
run();
}
clock_t t1 = clock();
printf("C++: %f %g\n", (float)(t1 - t0) / CLOCKS_PER_SEC, ans);
}
``````

CUDA:

``````#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <algorithm>

namespace {

// Advance every shot by one time step on the GPU.
//
// Launch layout: each thread updates one cell. Threads along x and y cover
// the interior of the grid (indices are offset by 2, since the stencil
// reaches two cells in each direction); the z dimension selects the shot.
// Reads `wfc` (current wavefield) and overwrites `wfp`, which holds the
// previous wavefield on entry and the next wavefield on exit. Wavefields are
// row-major [n_shots][ny][nx]; `v2dt2` is [ny][nx].
__global__ void step(float const *__restrict const wfc,
                     float *__restrict const wfp,
                     float const *__restrict const v2dt2,
                     float const one_over_dx2, float const one_over_dy2,
                     int const n_shots, int const ny, int const nx) {
  int const x = blockIdx.x * blockDim.x + threadIdx.x + 2;
  int const y = blockIdx.y * blockDim.y + threadIdx.y + 2;
  int const shot = blockIdx.z * blockDim.z + threadIdx.z;
  // The shot check keeps the kernel in bounds even if the launch has more
  // z threads than shots.
  if (shot < n_shots && y < ny - 2 && x < nx - 2) {
    int const i = shot * ny * nx + y * nx + x;

    // Second difference across rows (neighbours are nx elements apart).
    float const d2wdx2 =
        (-5.0f / 2.0f * wfc[i] +
         4.0f / 3.0f * (wfc[i + nx] + wfc[i - nx]) +
         -1.0f / 12.0f * (wfc[i + 2 * nx] + wfc[i - 2 * nx])) *
        one_over_dx2;

    // Second difference along a row (neighbours are adjacent in memory).
    float const d2wdy2 =
        (-5.0f / 2.0f * wfc[i] +
         4.0f / 3.0f * (wfc[i + 1] + wfc[i - 1]) +
         -1.0f / 12.0f * (wfc[i + 2] + wfc[i - 2])) *
        one_over_dy2;

    wfp[i] = (v2dt2[y * nx + x] * (d2wdx2 + d2wdy2) + 2 * wfc[i] - wfp[i]);
  }
}

__global__ void add_sources(float *__restrict wf, float const *__restrict f,
int const *__restrict sources_i, int n_shots) {
auto shot_idx{blockIdx.x * blockDim.x + threadIdx.x};
if (shot_idx < n_shots) {
wf[sources_i[shot_idx]] += f[shot_idx];
}
}

// Sample the wavefield at each shot's receiver, one thread per shot.
//
// Thread k (for k < n_shots) copies the wavefield element at flat index
// receivers_i[k] into r[k]. The signature follows the host-side C++
// record_receivers in the same post.
__global__ void record_receivers(float *__restrict r,
                                 float const *__restrict wf,
                                 int const *__restrict receivers_i,
                                 int n_shots) {
  auto shot_idx{blockIdx.x * blockDim.x + threadIdx.x};
  if (shot_idx < n_shots) {
    r[shot_idx] = wf[receivers_i[shot_idx]];
  }
}

// Set up a constant-velocity model with one source and one receiver per shot
// (both at grid cell (2, 2)), propagate for nt time steps on the GPU, copy
// the recorded receiver data back, and return a scalar computed from it.
float run() {
  int n_shots = 10;
  int ny = 200;
  int nx = 200;
  int nt = 1000;
  float dt = 0.0005f;
  float dy = 5.0f;
  float dx = 5.0f;
  float freq = 25.0f;
  float peak_time = 1.5f / freq;
  float *v2dt2 = (float *)malloc(ny * nx * sizeof(float));
  float *v2dt2_d;
  float *wfc_d;
  float *wfp_d;
  float *source_amplitudes = (float *)malloc(n_shots * nt * sizeof(float));
  float *source_amplitudes_d;
  float *receiver_amplitudes = (float *)malloc(n_shots * nt * sizeof(float));
  float *receiver_amplitudes_d;
  int *source_locations = (int *)malloc(n_shots * sizeof(int));
  int *source_locations_d;
  int *receiver_locations = (int *)malloc(n_shots * sizeof(int));
  int *receiver_locations_d;
  float one_over_dx2 = 1.0f / (dx * dx);
  float one_over_dy2 = 1.0f / (dy * dy);  // was dy * dx
  dim3 dimBlock(32, 32, 1);
  // x/y cover the (n - 4)-wide interior; z has one block per shot.
  dim3 dimGrid((nx - 4 + 31) / 32, (ny - 4 + 31) / 32, n_shots);
  dim3 dimBlock_srcrcv(32, 1, 1);
  dim3 dimGrid_srcrcv((n_shots + 31) / 32, 1, 1);

  // Abort with a message on any CUDA runtime failure instead of silently
  // continuing with a sticky error.
  auto check = [](cudaError_t err, char const *what) {
    if (err != cudaSuccess) {
      fprintf(stderr, "CUDA error (%s): %s\n", what, cudaGetErrorString(err));
      exit(EXIT_FAILURE);
    }
  };

  // Constant velocity of 1500 everywhere.
  for (int y = 0; y < ny; ++y) {
    for (int x = 0; x < nx; ++x) {  // was x < ny, wrong when nx != ny
      v2dt2[y * nx + x] = 1500 * 1500 * dt * dt;
    }
  }
  // Ricker wavelet, identical for every shot, stored as [nt][n_shots].
  for (int i = 0; i < nt; ++i) {
    float t = i * dt - peak_time;
    float a = M_PI * freq * t;
    a = a * a;
    for (int shot = 0; shot < n_shots; ++shot) {
      source_amplitudes[i * n_shots + shot] = (1 - 2 * a) * expf(-a);
    }
  }
  // Flat wavefield indices of cell (y=2, x=2) in each shot.
  for (int shot = 0; shot < n_shots; ++shot) {
    source_locations[shot] = shot * ny * nx + 2 * nx + 2;
    receiver_locations[shot] = shot * ny * nx + 2 * nx + 2;
  }

  check(cudaMalloc(&v2dt2_d, ny * nx * sizeof(float)), "cudaMalloc v2dt2");
  check(cudaMalloc(&wfc_d, n_shots * ny * nx * sizeof(float)),
        "cudaMalloc wfc");
  check(cudaMalloc(&wfp_d, n_shots * ny * nx * sizeof(float)),
        "cudaMalloc wfp");
  check(cudaMalloc(&source_amplitudes_d, n_shots * nt * sizeof(float)),
        "cudaMalloc source_amplitudes");
  check(cudaMalloc(&receiver_amplitudes_d, n_shots * nt * sizeof(float)),
        "cudaMalloc receiver_amplitudes");
  check(cudaMalloc(&source_locations_d, n_shots * sizeof(int)),
        "cudaMalloc source_locations");
  check(cudaMalloc(&receiver_locations_d, n_shots * sizeof(int)),
        "cudaMalloc receiver_locations");

  check(cudaMemcpy(v2dt2_d, v2dt2, ny * nx * sizeof(float),
                   cudaMemcpyHostToDevice),
        "cudaMemcpy v2dt2");
  check(cudaMemcpy(source_amplitudes_d, source_amplitudes,
                   n_shots * nt * sizeof(float), cudaMemcpyHostToDevice),
        "cudaMemcpy source_amplitudes");
  // Location arrays hold ints, so size the copies with sizeof(int).
  check(cudaMemcpy(source_locations_d, source_locations,
                   n_shots * sizeof(int), cudaMemcpyHostToDevice),
        "cudaMemcpy source_locations");
  check(cudaMemcpy(receiver_locations_d, receiver_locations,
                   n_shots * sizeof(int), cudaMemcpyHostToDevice),
        "cudaMemcpy receiver_locations");
  check(cudaMemset(wfc_d, 0, n_shots * ny * nx * sizeof(float)),
        "cudaMemset wfc");
  check(cudaMemset(wfp_d, 0, n_shots * ny * nx * sizeof(float)),
        "cudaMemset wfp");
  check(cudaMemset(receiver_amplitudes_d, 0, n_shots * nt * sizeof(float)),
        "cudaMemset receiver_amplitudes");

  for (int t = 0; t < nt; ++t) {
    // step writes the next wavefield into wfp_d.
    step<<<dimGrid, dimBlock>>>(wfc_d, wfp_d, v2dt2_d, one_over_dx2,
                                one_over_dy2, n_shots, ny, nx);
    add_sources<<<dimGrid_srcrcv, dimBlock_srcrcv>>>(
        wfp_d, source_amplitudes_d + t * n_shots, source_locations_d, n_shots);
    // NOTE(review): only the tail of this launch survived in the listing;
    // recording from the freshly updated wavefield is a reconstruction -
    // confirm against the original program.
    record_receivers<<<dimGrid_srcrcv, dimBlock_srcrcv>>>(
        receiver_amplitudes_d + t * n_shots, wfp_d, receiver_locations_d,
        n_shots);
    std::swap(wfp_d, wfc_d);
  }
  // Launch-configuration errors are reported here; execution errors surface
  // at the blocking copy below.
  check(cudaGetLastError(), "kernel launch");

  check(cudaMemcpy(receiver_amplitudes, receiver_amplitudes_d,
                   n_shots * nt * sizeof(float), cudaMemcpyDeviceToHost),
        "cudaMemcpy receiver_amplitudes");

  // NOTE(review): the original computation of `ans` was lost from this
  // listing; a plain sum of the recorded data is used as a stand-in - confirm.
  float ans = 0.0f;
  for (int i = 0; i < n_shots * nt; ++i) {
    ans += receiver_amplitudes[i];
  }

  cudaFree(v2dt2_d);
  cudaFree(wfc_d);
  cudaFree(wfp_d);
  cudaFree(source_amplitudes_d);
  cudaFree(receiver_amplitudes_d);
  cudaFree(source_locations_d);
  cudaFree(receiver_locations_d);
  free(v2dt2);
  free(source_amplitudes);
  free(receiver_amplitudes);
  free(source_locations);
  free(receiver_locations);

  return ans;
}

}  // namespace

int main(void) {
float ans = 0.0f;

// warmup
for (int i = 0; i < 3; ++i) {
ans = run();
}

// timing
clock_t t0 = clock();
for (int i = 0; i < 10; ++i) {
run();
}
clock_t t1 = clock();
printf("CUDA: %f %g\n", (float)(t1 - t0) / CLOCKS_PER_SEC, ans);
}
``````

To run:

``````!g++ -Ofast -march=native wave.cpp
!./a.out
!nvcc wave.cu
!./a.out
!python wave.py
``````

Output:

``````C++: 2.859534 13.1366
CUDA: 0.272206 13.1366
eager cpu 70.94255131099999 tensor(13.1366)
eager gpu 4.180636458000009 tensor(13.1366, device='cuda:0')
compile cpu 19.22466220399997 tensor(13.1366)
compile gpu 1.8795379320000052 tensor(13.1366, device='cuda:0')
jitscript cpu 23.204399749000004 tensor(13.1366)
jitscript gpu 1.8472471919999975 tensor(13.1366, device='cuda:0')
``````

The fastest PyTorch on the CPU is 19s (`torch.compile`), compared to 2.9s for C++.
The fastest PyTorch on the GPU is 1.8s (both `torch.compile` and `jit.script`), compared to 0.3s for CUDA.

Your profiling on the GPU is invalid since you are not synchronizing the code (unless I'm missing it somewhere). Besides that, you could try to use CUDA Graphs in PyTorch to reduce the dispatching and kernel launch overhead, as it seems these take the majority of the time while the actual workload is not dominating the timeline.

1 Like

I do not believe that explicit synchronization is necessary as the blocking calls (`cudaMemcpy`, `cudaFree`, etc.) will cause implicit synchronization. Nevertheless, I added a call to `cudaDeviceSynchronize()` at the end of the `run` function and re-ran the timing. As expected, the results were similar.

I modified the Python code to use graphs. I used two approaches.

The first (the new `Run` class) captures a graph of the entire loop over time steps in the constructor. The `run` method then only needs to replay the graph. Note that this method has a small unfair advantage over the C++/CUDA and other Python approaches, as it only allocates memory and performs setup when the object is created, rather than for every run.

The second (`closure_step`) only creates a graph that captures two time steps. The loop over time steps is thus only performed `(nt-1)//2` times, each replaying this graph. In order for the inputs and outputs of the graph to be the same memory address each time, I copy the source amplitudes for those two time steps into a tensor `source_amplitudes_two`, which is what is used inside the graph, and do something similar for the recorded receiver data for those two time steps.

``````import math
import torch
from torch import Tensor
from timeit import timeit

def ricker(freq: float, length: int, dt: float, peak_time: float) -> Tensor:
"""Ricker wavelet for source."""
t = torch.arange(length) * dt - peak_time
a = math.pi**2 * freq**2 * t**2
return (1 - 2 * a) * torch.exp(-a)

def step(wfc: Tensor, wfp: Tensor, v2dt2: Tensor,
         one_over_dx2: float, one_over_dy2: float):
    """One finite difference time step of wave equation, updating in place.

    Args:
        wfc: current time step wavefield, [shot, x, y]
        wfp: previous time step wavefield; its interior is overwritten
            with the wavefield at the next time step
        v2dt2: velocity^2 * dt^2, defined on the interior of the grid
            (i.e. broadcastable against wfc[:, 2:-2, 2:-2])
        one_over_dx2/dy2: 1/dx^2

    Returns:
        (wfp, wfc): the same two tensors with their roles swapped, so the
        caller can assign them to (current, previous) for the next step.
    """
    d2wdx2 = (
        -5 / 2 * wfc[:, 2:-2, 2:-2]
        + 4 / 3 * (wfc[:, 3:-1, 2:-2] + wfc[:, 1:-3, 2:-2])
        + -1 / 12 * (wfc[:, 4:, 2:-2] + wfc[:, :-4, 2:-2])
    ) * one_over_dx2
    d2wdy2 = (
        -5 / 2 * wfc[:, 2:-2, 2:-2]
        + 4 / 3 * (wfc[:, 2:-2, 3:-1] + wfc[:, 2:-2, 1:-3])
        + -1 / 12 * (wfc[:, 2:-2, 4:] + wfc[:, 2:-2, :-4])
    ) * one_over_dy2

    # wavefield at next time step, written over the previous one so that no
    # new tensor is allocated; the two-cell border of wfp is left untouched
    wfp[:, 2:-2, 2:-2] = (
        v2dt2 * (d2wdx2 + d2wdy2)
        + 2 * wfc[:, 2:-2, 2:-2]
        - wfp[:, 2:-2, 2:-2]
    )
    return wfp, wfc

class Run():
def __init__(self, step_fn, device,
n_shots: int = 10, ny: int = 200, nx: int = 200,
nt: int = 1000, dt: float = 0.0005,
dy: float = 5, dx: float = 5, freq: float = 25):

# Setup input velocity, source, receiver
v = 1500*torch.ones(ny, nx, device=device)
self.v2dt2 = v[2:-2, 2:-2]**2 * dt**2
self.wfc = torch.zeros(n_shots, ny, nx, device=device)
self.wfp = torch.zeros(n_shots, ny, nx, device=device)
self.one_over_dx2 = 1 / dx**2
self.one_over_dy2 = 1 / dy**2
self.nt = nt
self.step_fn = step_fn
self.source_amplitudes = (
ricker(freq, nt, dt, 1.5/freq)
.to(device)
.reshape(-1, 1)
.repeat(1, n_shots)
)
self.sources_batch = torch.arange(n_shots, dtype=torch.long,
device=device)
self.sources_y = 2 * torch.ones(n_shots, dtype=torch.long,
device=device)
self.sources_x = 2 * torch.ones(n_shots, dtype=torch.long,
device=device)
device=device)
self.receivers_y = 2 * torch.ones(n_shots, dtype=torch.long,
device=device)
self.receivers_x = 2 * torch.ones(n_shots, dtype=torch.long,
device=device)
# Warmup
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
self.loop()
torch.cuda.current_stream().wait_stream(s)

# Capture graph
self.g = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.g):
self.loop()

def loop(self):
self.wfc.fill_(0)
self.wfp.fill_(0)
wfc = self.wfc
wfp = self.wfp
# Loop over time steps
for t in range(self.nt-1):
wfc, wfp = self.step_fn(wfc, wfp, self.v2dt2,
self.one_over_dx2, self.one_over_dy2)
wfc[self.sources_batch,
self.sources_y,
self.sources_x] += self.source_amplitudes[t]

def run(self):
self.g.replay()

def closure_step(step_fn, device):
def run(n_shots: int = 10, ny: int = 200, nx: int = 200,
nt: int = 1000, dt: float = 0.0005,
dy: float = 5, dx: float = 5, freq: float = 25):

# Setup input velocity, source, receiver
v = 1500*torch.ones(ny, nx, device=device)
v2dt2 = v[2:-2, 2:-2]**2 * dt**2
wfc = torch.zeros(n_shots, ny, nx, device=device)
wfp = torch.zeros(n_shots, ny, nx, device=device)
one_over_dx2 = 1 / dx**2
one_over_dy2 = 1 / dy**2
source_amplitudes = (
ricker(freq, nt, dt, 1.5/freq)
.to(device)
.reshape(-1, 1)
.repeat(1, n_shots)
)
sources_batch = torch.arange(n_shots, dtype=torch.long,
device=device)
sources_y = 2 * torch.ones(n_shots, dtype=torch.long,
device=device)
sources_x = 2 * torch.ones(n_shots, dtype=torch.long,
device=device)
device=device)
receivers_y = 2 * torch.ones(n_shots, dtype=torch.long, device=device)
receivers_x = 2 * torch.ones(n_shots, dtype=torch.long, device=device)

source_amplitudes_two = torch.zeros(2, n_shots, device=device)

def two_steps(wfc, wfp, source_amplitudes_two,
for i in range(2):
wfc, wfp = step_fn(wfc, wfp, v2dt2, one_over_dx2, one_over_dy2)
wfc[sources_batch,
sources_y,
sources_x] += source_amplitudes_two[i]

# Warmup
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
torch.cuda.current_stream().wait_stream(s)

# Capture graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):

# Loop over graph
wfc.fill_(0)
wfp.fill_(0)
for t in range(0, nt-1, 2):
source_amplitudes_two[:] = source_amplitudes[t:t+2]
g.replay()

# Do something with output so compiler doesn't optimise away
return run

def timing(name, run_fn, n_warmup=3):
ans = 0
for _ in range(n_warmup):
ans = run_fn()
print(name, timeit(lambda: run_fn(), number=10), ans)

def eager():
timing('eager', Run(step, torch.device('cuda:0')).run)
timing('eager step', closure_step(step, torch.device('cuda:0')))

def torchcompile():
timing('compile', Run(torch.compile(step, fullgraph=True),
torch.device('cuda:0')).run)
timing('compile step',
closure_step(torch.compile(step, fullgraph=True),
torch.device('cuda:0')))

def jitscript():
timing('jitscript', Run(torch.jit.script(step),
torch.device('cuda:0')).run)
timing('jitscript step', closure_step(torch.jit.script(step),
torch.device('cuda:0')))

eager()
torchcompile()
jitscript()
``````

This time, the results were:

``````CUDA (original): 0.201396 13.1366
CUDA (extra synchronization): 0.199689 13.1366
eager 2.4772268719998465 tensor(13.1366, device='cuda:0')
eager step 1.6758502800003043 tensor(13.1366, device='cuda:0')
compile 0.592774092000127 tensor(13.1366, device='cuda:0')
compile step 0.9349224309999045 tensor(13.1366, device='cuda:0')
jitscript 0.707418725000025 tensor(13.1366, device='cuda:0')
jitscript step 1.0656032869997034 tensor(13.1366, device='cuda:0')
``````

(` step` refers to the second approach, where only two steps are captured by the graph)

While somewhat improved, there is unfortunately still a substantial (about 3X) performance difference between my hand-coded but not substantially optimized CUDA and the fastest PyTorch approach (graph of the full loop, using `torch.compile` on the `step` function).

1 Like

To isolate the computation more clearly, here is a modification that just performs the time stepping (no sources/receivers). I have written three versions - one for torch.compile, one directly in Triton, and one in CUDA. They are written in what I think is the most natural way for each approach, which results in small differences between them (the torch.compile returns its result in a new Tensor, while the other two overwrite the previous wavefield, for example), but I hope you will agree that the results are still reasonable measures of the performance of each approach.

Timings:

• torch.compile: 0.13s
• Triton: 0.1s
• CUDA: 0.08s

It seems that in this case torch.compile works pretty well. I will continue to explore. I know this is not a typical deep learning application, and so it is understandable that it is not a priority for PyTorch, but I think there are many scientific and engineering applications that are similar and would benefit from being able to get good forward, and in some cases (such as mine) backpropagation, performance, using PyTorch.

torch.compile:

``````import torch
from torch import Tensor

def step(wfc: Tensor, wfp: Tensor, v2dt2: Tensor,
         one_over_dx2: float, one_over_dy2: float) -> Tensor:
    """One finite difference time step over the whole grid.

    The current wavefield is zero-padded by two cells on each spatial edge so
    that the five-point stencil can be evaluated at every cell of the
    original grid.

    Args:
        wfc: current time step wavefield, [shot, ny, nx]
        wfp: previous time step wavefield, same shape as wfc
        v2dt2: velocity^2 * dt^2, [ny, nx]
        one_over_dx2/dy2: 1/dx^2

    Returns:
        A new Tensor with the wavefield at the next time step, same shape
        as wfc.
    """
    wfcp = torch.nn.functional.pad(wfc, (2, 2, 2, 2))
    return (
        v2dt2[None] * (
            (
                -5 / 2 * wfc
                + 4 / 3 * (wfcp[:, 3:-1, 2:-2] + wfcp[:, 1:-3, 2:-2])
                + -1 / 12 * (wfcp[:, 4:, 2:-2] + wfcp[:, :-4, 2:-2])
            ) * one_over_dx2 +
            (
                -5 / 2 * wfc
                + 4 / 3 * (wfcp[:, 2:-2, 3:-1] + wfcp[:, 2:-2, 1:-3])
                + -1 / 12 * (wfcp[:, 2:-2, 4:] + wfcp[:, 2:-2, :-4])
            ) * one_over_dy2
        )
        + 2 * wfc
        - wfp
    )

step = torch.compile(step)

n_shots = 50
ny = 500
nx = 500
nt = 100
torch.manual_seed(1)
v2dt2 = torch.rand(ny, nx, device=torch.device('cuda')) - 0.5
wfc = torch.rand(n_shots, ny, nx, device=torch.device('cuda')) - 0.5
wfp = torch.rand(n_shots, ny, nx, device=torch.device('cuda')) - 0.5
one_over_dy2 = one_over_dx2 = 1 / 5**2

for _ in range(nt):
wfn = step(wfc, wfp, v2dt2, one_over_dx2, one_over_dy2)
wfc, wfp = wfn, wfc

import time
torch.cuda.synchronize()
t0 = time.time()
for _ in range(nt):
wfn = step(wfc, wfp, v2dt2, one_over_dx2, one_over_dy2)
wfc, wfp = wfn, wfc
torch.cuda.synchronize()
t1 = time.time()
print(t1 - t0)
``````

Triton:

``````import torch
import triton
import triton.language as tl

@triton.jit
def step(wfc_ptr, wfp_ptr, v2dt2_ptr, one_over_dx2, one_over_dy2,
         n_shots, ny, nx,
         BLOCKSIZE_Y: tl.constexpr, BLOCKSIZE_X: tl.constexpr):
    # One finite difference time step. Each program handles a
    # BLOCKSIZE_Y x BLOCKSIZE_X tile of one shot: program axes 0 and 1 select
    # the tile along x and y, axis 2 selects the shot. The next wavefield is
    # stored over the previous wavefield (wfp_ptr). Every load is masked so
    # that positions outside the [ny, nx] grid read as 0.0.
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)
    pid_shot = tl.program_id(axis=2)
    y_start = pid_y * BLOCKSIZE_Y
    x_start = pid_x * BLOCKSIZE_X
    stride_shot = ny * nx
    offs_y = y_start + tl.arange(0, BLOCKSIZE_Y)
    offs_x = x_start + tl.arange(0, BLOCKSIZE_X)
    offs = pid_shot * stride_shot + offs_y[:, None] * nx + offs_x[None, :]

    # Centre value of the stencil.
    wfc_c = tl.load(wfc_ptr + offs,
                    mask=(offs_y[:, None] < ny) &
                    (offs_x[None, :] < nx),
                    other=0.0)

    # Neighbours one and two rows away (rows are nx elements apart).
    wfc_up = tl.load(wfc_ptr + offs - nx,
                     mask=(0 <= offs_y[:, None] - 1) &
                     (offs_y[:, None] - 1 < ny) &
                     (offs_x[None, :] < nx),
                     other=0.0)
    wfc_up2 = tl.load(wfc_ptr + offs - 2*nx,
                      mask=(0 <= offs_y[:, None] - 2) &
                      (offs_y[:, None] - 2 < ny) &
                      (offs_x[None, :] < nx),
                      other=0.0)
    wfc_down = tl.load(wfc_ptr + offs + nx,
                       mask=(offs_y[:, None] + 1 < ny) &
                       (offs_x[None, :] < nx),
                       other=0.0)
    wfc_down2 = tl.load(wfc_ptr + offs + 2*nx,
                        mask=(offs_y[:, None] + 2 < ny) &
                        (offs_x[None, :] < nx),
                        other=0.0)

    # Neighbours one and two columns away (adjacent in memory).
    wfc_left = tl.load(wfc_ptr + offs - 1,
                       mask=(offs_y[:, None] < ny) &
                       (0 <= offs_x[None, :] - 1) &
                       (offs_x[None, :] - 1 < nx),
                       other=0.0)
    wfc_left2 = tl.load(wfc_ptr + offs - 2,
                        mask=(offs_y[:, None] < ny) &
                        (0 <= offs_x[None, :] - 2) &
                        (offs_x[None, :] - 2 < nx),
                        other=0.0)
    wfc_right = tl.load(wfc_ptr + offs + 1,
                        mask=(offs_y[:, None] < ny) &
                        (offs_x[None, :] + 1 < nx),
                        other=0.0)
    wfc_right2 = tl.load(wfc_ptr + offs + 2,
                         mask=(offs_y[:, None] < ny) &
                         (offs_x[None, :] + 2 < nx),
                         other=0.0)

    wfp = tl.load(wfp_ptr + offs,
                  mask=(offs_y[:, None] < ny) &
                  (offs_x[None, :] < nx),
                  other=0.0)
    # v2dt2 is shared by all shots, so it is indexed without the shot offset.
    v2dt2 = tl.load(v2dt2_ptr + offs_y[:, None] * nx + offs_x[None, :],
                    mask=(offs_y[:, None] < ny) &
                    (offs_x[None, :] < nx),
                    other=0.0)

    tl.store(wfp_ptr + offs,
             v2dt2 * (
                 (
                     -5 / 2 * wfc_c
                     + 4 / 3 * (wfc_up + wfc_down)
                     + -1 / 12 * (wfc_up2 + wfc_down2)
                 ) * one_over_dx2 +
                 (
                     -5 / 2 * wfc_c
                     + 4 / 3 * (wfc_left + wfc_right)
                     + -1 / 12 * (wfc_left2 + wfc_right2)
                 ) * one_over_dy2
             )
             + 2 * wfc_c
             - wfp,
             mask=(offs_y[:, None] < ny) & (offs_x[None, :] < nx))

n_shots = 50
ny = 500
nx = 500
nt = 100
torch.manual_seed(1)
v2dt2 = torch.rand(ny, nx, device=torch.device('cuda')) - 0.5
wfc = torch.rand(n_shots, ny, nx, device=torch.device('cuda')) - 0.5
wfp = torch.rand(n_shots, ny, nx, device=torch.device('cuda')) - 0.5
one_over_dy2 = one_over_dx2 = 1 / 5**2

grid = (triton.cdiv(nx, 32), triton.cdiv(ny, 32), n_shots)

for _ in range(nt):
step[grid](wfc, wfp, v2dt2, one_over_dx2, one_over_dy2, n_shots,
ny, nx, BLOCKSIZE_Y=32, BLOCKSIZE_X=32)
wfc, wfp = wfp, wfc

import time
torch.cuda.synchronize()
t0 = time.time()
for _ in range(nt):
step[grid](wfc, wfp, v2dt2, one_over_dx2, one_over_dy2, n_shots,
ny, nx, BLOCKSIZE_Y=32, BLOCKSIZE_X=32)
wfc, wfp = wfp, wfc
torch.cuda.synchronize()
t1 = time.time()
print(t1 - t0)
``````

CUDA:

``````#include <stdio.h>
#include <stdlib.h>
#include <time.h>

namespace {

// Advance every shot by one time step on the GPU.
//
// Launch layout: each thread updates one cell. Threads along x and y cover
// the interior of the grid (indices are offset by 2, since the stencil
// reaches two cells in each direction); the z dimension selects the shot.
// Reads `wfc` (current wavefield) and overwrites `wfp`, which holds the
// previous wavefield on entry and the next wavefield on exit. Wavefields are
// row-major [n_shots][ny][nx]; `v2dt2` is [ny][nx].
__global__ void step(float const *__restrict const wfc,
                     float *__restrict const wfp,
                     float const *__restrict const v2dt2,
                     float const one_over_dx2, float const one_over_dy2,
                     int const n_shots, int const ny, int const nx) {
  int const x = blockIdx.x * blockDim.x + threadIdx.x + 2;
  int const y = blockIdx.y * blockDim.y + threadIdx.y + 2;
  int const shot = blockIdx.z * blockDim.z + threadIdx.z;
  // Threads outside the interior, or beyond the last shot, have no cell to
  // update. The shot check keeps the kernel in bounds even if the launch has
  // more z threads than shots.
  if (shot >= n_shots || y >= ny - 2 || x >= nx - 2) {
    return;
  }
  int const i = shot * ny * nx + y * nx + x;
  wfp[i] =
      v2dt2[y * nx + x] *
          ((-5.0f / 2.0f * wfc[i] +
            4.0f / 3.0f * (wfc[i + nx] + wfc[i - nx]) +
            -1.0f / 12.0f * (wfc[i + 2 * nx] + wfc[i - 2 * nx])) *
               one_over_dx2 +
           (-5.0f / 2.0f * wfc[i] + 4.0f / 3.0f * (wfc[i + 1] + wfc[i - 1]) +
            -1.0f / 12.0f * (wfc[i + 2] + wfc[i - 2])) *
               one_over_dy2) +
      2 * wfc[i] - wfp[i];
}

}  // namespace

int main() {
int n_shots = 50;
int ny = 504;
int nx = 504;
int nt = 100;
float *v2dt2 = (float *)calloc(ny * nx, sizeof(float));
float *v2dt2_d;
float *wfc = (float *)calloc(n_shots * ny * nx, sizeof(float));
float *wfc_d;
float *wfp = (float *)calloc(n_shots * ny * nx, sizeof(float));
float *wfp_d;
float one_over_dx2 = 1.0f / 25.0f;
float one_over_dy2 = 1.0f / 25.0f;
dim3 dimBlock(32, 32, 1);
dim3 dimGrid((nx - 4 + 31) / 32, (ny - 4 + 31) / 32, n_shots);

srand(1);
for (int y = 2; y < ny - 2; ++y) {
for (int x = 2; x < ny - 2; ++x) {
v2dt2[y * nx + x] = (float)rand() / RAND_MAX - 0.5f;
}
}
for (int shot = 0; shot < n_shots; ++shot) {
for (int y = 2; y < ny - 2; ++y) {
for (int x = 2; x < ny - 2; ++x) {
wfc[shot * ny * nx + y * nx + x] = (float)rand() / RAND_MAX - 0.5f;
}
}
}
for (int shot = 0; shot < n_shots; ++shot) {
for (int y = 2; y < ny - 2; ++y) {
for (int x = 2; x < ny - 2; ++x) {
wfp[shot * ny * nx + y * nx + x] = (float)rand() / RAND_MAX - 0.5f;
}
}
}

cudaMalloc(&v2dt2_d, ny * nx * sizeof(float));
cudaMalloc(&wfc_d, n_shots * ny * nx * sizeof(float));
cudaMalloc(&wfp_d, n_shots * ny * nx * sizeof(float));

cudaMemcpy(v2dt2_d, v2dt2, ny * nx * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(wfc_d, wfc, n_shots * ny * nx * sizeof(float),
cudaMemcpyHostToDevice);
cudaMemcpy(wfp_d, wfp, n_shots * ny * nx * sizeof(float),
cudaMemcpyHostToDevice);

clock_t t0 = clock();
for (int t = 0; t < nt; ++t) {
step<<<dimGrid, dimBlock>>>(wfc_d, wfp_d, v2dt2_d, one_over_dx2,
one_over_dy2, n_shots, ny, nx);
float *tmp = wfc_d;
wfc_d = wfp_d;
wfp_d = tmp;
}