Example where PyTorch is substantially slower than C++/CUDA

In case it interests the developers, here is a notable example (finite difference wave propagation) where PyTorch eager mode, torch.compile, and torch.jit.script are all substantially (5X+) slower than C++ and CUDA.

PyTorch:

import math
import torch
from torch import Tensor
from timeit import timeit


def ricker(freq: float, length: int, dt: float, peak_time: float) -> Tensor:
    """Ricker wavelet for source."""
    t = torch.arange(length) * dt - peak_time
    a = math.pi**2 * freq**2 * t**2
    return (1 - 2 * a) * torch.exp(-a)


def step(wfc: Tensor, wfp: Tensor, v2dt2: Tensor,
         one_over_dx2: float, one_over_dy2: float):
    """One finite difference time step of wave equation.

    Args:
       wfc: current time step wavefield, [shot, x, y]
       wfp: previous time step wavefield
       v2dt2: velocity^2 * dt^2
       one_over_dx2/dy2: 1/dx^2
    """
    d2wdx2 = (
        -5 / 2 * wfc[:, 2:-2, 2:-2]
        + 4 / 3 * (wfc[:, 3:-1, 2:-2] + wfc[:, 1:-3, 2:-2])
        + -1 / 12 * (wfc[:, 4:, 2:-2] + wfc[:, :-4, 2:-2])
    ) * one_over_dx2
    d2wdy2 = (
        -5 / 2 * wfc[:, 2:-2, 2:-2]
        + 4 / 3 * (wfc[:, 2:-2, 3:-1] + wfc[:, 2:-2, 1:-3])
        + -1 / 12 * (wfc[:, 2:-2, 4:] + wfc[:, 2:-2, :-4])
    ) * one_over_dy2

    # wavefield at next time step
    return torch.nn.functional.pad(
        v2dt2 * (d2wdx2 + d2wdy2)
        + 2 * wfc[:, 2:-2, 2:-2]
        - wfp[:, 2:-2, 2:-2], (2, 2, 2, 2)
    )


def closure(step_fn, device):
    def run(n_shots: int = 10, ny: int = 200, nx: int = 200,
            nt: int = 1000, dt: float = 0.0005,
            dy: float = 5, dx: float = 5, freq: float = 25):

        # Setup input velocity, source, receiver
        v = 1500*torch.ones(ny, nx, device=device)
        v2dt2 = v[2:-2, 2:-2]**2 * dt**2
        wfc = torch.zeros(n_shots, ny, nx, device=device)
        wfp = torch.zeros(n_shots, ny, nx, device=device)
        one_over_dx2 = 1 / dx**2
        one_over_dy2 = 1 / dy**2
        source_amplitudes = (
            ricker(freq, nt, dt, 1.5/freq)
            .to(device)
            .reshape(-1, 1)
            .repeat(1, n_shots)
        )
        sources_batch = torch.arange(n_shots, dtype=torch.long,
                                     device=device)
        sources_y = 2 * torch.ones(n_shots, dtype=torch.long,
                                   device=device)
        sources_x = 2 * torch.ones(n_shots, dtype=torch.long,
                                   device=device)
        receivers_batch = torch.arange(n_shots, dtype=torch.long,
                                       device=device)
        receivers_y = 2 * torch.ones(n_shots, dtype=torch.long, device=device)
        receivers_x = 2 * torch.ones(n_shots, dtype=torch.long, device=device)
        receiver_amplitudes = torch.zeros(nt, n_shots, device=device)

        # Loop over time steps
        for t in range(nt-1):
            wfn = step_fn(wfc, wfp, v2dt2, one_over_dx2, one_over_dy2)
            wfn[sources_batch,
                sources_y,
                sources_x] += source_amplitudes[t]
            receiver_amplitudes[t] = \
                wfc[receivers_batch, receivers_x, receivers_y]
            wfp, wfc = wfc, wfn

        # Do something with output so compiler doesn't optimise away
        return receiver_amplitudes.max()
    return run


def timing(name, run_fn, n_warmup=3):
    ans = 0
    for _ in range(n_warmup):
        ans = run_fn()
    print(name, timeit(lambda: run_fn(), number=10), ans)


def eager():
    with torch.no_grad():
        timing('eager cpu', closure(step, torch.device('cpu')))
        timing('eager gpu', closure(step, torch.device('cuda:0')))


def torchcompile():
    with torch.no_grad():
        timing('compile cpu', closure(torch.compile(step, fullgraph=True),
                                      torch.device('cpu')))
        timing('compile gpu', closure(torch.compile(step, fullgraph=True),
                                      torch.device('cuda:0')))


def jitscript():
    with torch.no_grad():
        timing('jitscript cpu', closure(torch.jit.script(step),
                                        torch.device('cpu')))
        timing('jitscript gpu', closure(torch.jit.script(step),
                                        torch.device('cuda:0')))


eager()
torchcompile()
jitscript()

C++:

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <algorithm>

namespace {

void step(float const *__restrict const wfc, float *__restrict const wfp,
          float const *__restrict const v2dt2, float const one_over_dx2,
          float const one_over_dy2, int const n_shots, int const ny,
          int const nx) {
  for (int shot = 0; shot < n_shots; ++shot) {
    for (int y = 2; y < ny - 2; ++y) {
      for (int x = 2; x < nx - 2; ++x) {
        float d2wdx2 = (-5.0f / 2.0f * wfc[shot * ny * nx + y * nx + x] +
                        4.0f / 3.0f *
                            (wfc[shot * ny * nx + (y + 1) * nx + x] +
                             wfc[shot * ny * nx + (y - 1) * nx + x]) +
                        -1.0f / 12.0f *
                            (wfc[shot * ny * nx + (y + 2) * nx + x] +
                             wfc[shot * ny * nx + (y - 2) * nx + x])) *
                       one_over_dx2;
        float d2wdy2 = (-5.0f / 2.0f * wfc[shot * ny * nx + y * nx + x] +
                        4.0f / 3.0f *
                            (wfc[shot * ny * nx + y * nx + x + 1] +
                             wfc[shot * ny * nx + y * nx + x - 1]) +
                        -1.0f / 12.0f *
                            (wfc[shot * ny * nx + y * nx + x + 2] +
                             wfc[shot * ny * nx + y * nx + x - 2])) *
                       one_over_dy2;

        wfp[shot * ny * nx + y * nx + x] =
            (v2dt2[y * nx + x] * (d2wdx2 + d2wdy2) +
             2 * wfc[shot * ny * nx + y * nx + x] -
             wfp[shot * ny * nx + y * nx + x]);
      }
    }
  }
}

void add_sources(float *__restrict wf, float const *__restrict f,
                 int const *__restrict sources_i, int n_shots) {
  for (int shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
    wf[sources_i[shot_idx]] += f[shot_idx];
  }
}

void record_receivers(float *__restrict r, float const *__restrict wf,
                      int const *__restrict receivers_i, int n_shots) {
  for (int shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
    r[shot_idx] = wf[receivers_i[shot_idx]];
  }
}

float run() {
  int n_shots = 10;
  int ny = 200;
  int nx = 200;
  int nt = 1000;
  float dt = 0.0005f;
  float dy = 5.0f;
  float dx = 5.0f;
  float freq = 25.0f;
  float peak_time = 1.5f / freq;
  float *v2dt2 = (float *)malloc(ny * nx * sizeof(float));
  float *wfc = (float *)calloc(n_shots * ny * nx, sizeof(float));
  float *wfp = (float *)calloc(n_shots * ny * nx, sizeof(float));
  float *source_amplitudes = (float *)malloc(n_shots * nt * sizeof(float));
  float *receiver_amplitudes = (float *)malloc(n_shots * nt * sizeof(float));
  int *source_locations = (int *)malloc(n_shots * sizeof(int));
  int *receiver_locations = (int *)malloc(n_shots * sizeof(int));
  float one_over_dx2 = 1.0f / (dx * dx);
  float one_over_dy2 = 1.0f / (dy * dy);

  for (int y = 0; y < ny; ++y) {
    for (int x = 0; x < nx; ++x) {
      v2dt2[y * nx + x] = 1500 * 1500 * dt * dt;
    }
  }
  for (int i = 0; i < nt; ++i) {
    float t = i * dt - peak_time;
    float a = M_PI * freq * t;
    a = a * a;
    for (int shot = 0; shot < n_shots; ++shot) {
      source_amplitudes[i * n_shots + shot] = (1 - 2 * a) * expf(-a);
    }
  }
  for (int shot = 0; shot < n_shots; ++shot) {
    source_locations[shot] = shot * ny * nx + 2 * nx + 2;
    receiver_locations[shot] = shot * ny * nx + 2 * nx + 2;
  }

  for (int t = 0; t < nt; ++t) {
    step(wfc, wfp, v2dt2, one_over_dx2, one_over_dy2, n_shots, ny, nx);
    add_sources(wfp, source_amplitudes + t * n_shots, source_locations,
                n_shots);
    record_receivers(receiver_amplitudes + t * n_shots, wfc, receiver_locations,
                     n_shots);
    std::swap(wfc, wfp);
  }

  float ans = *std::max_element(receiver_amplitudes,
                                receiver_amplitudes + n_shots * nt);

  free(v2dt2);
  free(wfc);
  free(wfp);
  free(source_amplitudes);
  free(receiver_amplitudes);
  free(source_locations);
  free(receiver_locations);

  return ans;
}

}  // namespace

int main(void) {
  float ans = 0.0f;

  // warmup
  for (int i = 0; i < 3; ++i) {
    ans = run();
  }

  // timing
  clock_t t0 = clock();
  for (int i = 0; i < 10; ++i) {
    run();
  }
  clock_t t1 = clock();
  printf("C++: %f %g\n", (float)(t1 - t0) / CLOCKS_PER_SEC, ans);
}

CUDA:

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <algorithm>

namespace {

__global__ void step(float const *__restrict const wfc,
                     float *__restrict const wfp,
                     float const *__restrict const v2dt2,
                     float const one_over_dx2, float const one_over_dy2,
                     int const n_shots, int const ny, int const nx) {
  auto x{blockIdx.x * blockDim.x + threadIdx.x + 2};
  auto y{blockIdx.y * blockDim.y + threadIdx.y + 2};
  auto shot{blockIdx.z * blockDim.z + threadIdx.z};
  if (y < ny - 2 and x < nx - 2) {
    float d2wdx2 = (-5.0f / 2.0f * wfc[shot * ny * nx + y * nx + x] +
                    4.0f / 3.0f *
                        (wfc[shot * ny * nx + (y + 1) * nx + x] +
                         wfc[shot * ny * nx + (y - 1) * nx + x]) +
                    -1.0f / 12.0f *
                        (wfc[shot * ny * nx + (y + 2) * nx + x] +
                         wfc[shot * ny * nx + (y - 2) * nx + x])) *
                   one_over_dx2;
    float d2wdy2 = (-5.0f / 2.0f * wfc[shot * ny * nx + y * nx + x] +
                    4.0f / 3.0f *
                        (wfc[shot * ny * nx + y * nx + x + 1] +
                         wfc[shot * ny * nx + y * nx + x - 1]) +
                    -1.0f / 12.0f *
                        (wfc[shot * ny * nx + y * nx + x + 2] +
                         wfc[shot * ny * nx + y * nx + x - 2])) *
                   one_over_dy2;

    wfp[shot * ny * nx + y * nx + x] = (v2dt2[y * nx + x] * (d2wdx2 + d2wdy2) +
                                        2 * wfc[shot * ny * nx + y * nx + x] -
                                        wfp[shot * ny * nx + y * nx + x]);
  }
}

__global__ void add_sources(float *__restrict wf, float const *__restrict f,
                            int const *__restrict sources_i, int n_shots) {
  auto shot_idx{blockIdx.x * blockDim.x + threadIdx.x};
  if (shot_idx < n_shots) {
    wf[sources_i[shot_idx]] += f[shot_idx];
  }
}

__global__ void record_receivers(float *__restrict r,
                                 float const *__restrict wf,
                                 int const *__restrict receivers_i,
                                 int n_shots) {
  auto shot_idx{blockIdx.x * blockDim.x + threadIdx.x};
  if (shot_idx < n_shots) {
    r[shot_idx] = wf[receivers_i[shot_idx]];
  }
}

float run() {
  int n_shots = 10;
  int ny = 200;
  int nx = 200;
  int nt = 1000;
  float dt = 0.0005f;
  float dy = 5.0f;
  float dx = 5.0f;
  float freq = 25.0f;
  float peak_time = 1.5f / freq;
  float *v2dt2 = (float *)malloc(ny * nx * sizeof(float));
  float *v2dt2_d;
  float *wfc_d;
  float *wfp_d;
  float *source_amplitudes = (float *)malloc(n_shots * nt * sizeof(float));
  float *source_amplitudes_d;
  float *receiver_amplitudes = (float *)malloc(n_shots * nt * sizeof(float));
  float *receiver_amplitudes_d;
  int *source_locations = (int *)malloc(n_shots * sizeof(int));
  int *source_locations_d;
  int *receiver_locations = (int *)malloc(n_shots * sizeof(int));
  int *receiver_locations_d;
  float one_over_dx2 = 1.0f / (dx * dx);
  float one_over_dy2 = 1.0f / (dy * dy);
  dim3 dimBlock(32, 32, 1);
  dim3 dimGrid((nx - 4 + 31) / 32, (ny - 4 + 31) / 32, n_shots);
  dim3 dimBlock_srcrcv(32, 1, 1);
  dim3 dimGrid_srcrcv((n_shots + 31) / 32, 1, 1);

  for (int y = 0; y < ny; ++y) {
    for (int x = 0; x < nx; ++x) {
      v2dt2[y * nx + x] = 1500 * 1500 * dt * dt;
    }
  }
  for (int i = 0; i < nt; ++i) {
    float t = i * dt - peak_time;
    float a = M_PI * freq * t;
    a = a * a;
    for (int shot = 0; shot < n_shots; ++shot) {
      source_amplitudes[i * n_shots + shot] = (1 - 2 * a) * expf(-a);
    }
  }
  for (int shot = 0; shot < n_shots; ++shot) {
    source_locations[shot] = shot * ny * nx + 2 * nx + 2;
    receiver_locations[shot] = shot * ny * nx + 2 * nx + 2;
  }

  cudaMalloc(&v2dt2_d, ny * nx * sizeof(float));
  cudaMalloc(&wfc_d, n_shots * ny * nx * sizeof(float));
  cudaMalloc(&wfp_d, n_shots * ny * nx * sizeof(float));
  cudaMalloc(&source_amplitudes_d, n_shots * nt * sizeof(float));
  cudaMalloc(&receiver_amplitudes_d, n_shots * nt * sizeof(float));
  cudaMalloc(&source_locations_d, n_shots * sizeof(int));
  cudaMalloc(&receiver_locations_d, n_shots * sizeof(int));

  cudaMemcpy(v2dt2_d, v2dt2, ny * nx * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(source_amplitudes_d, source_amplitudes,
             n_shots * nt * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(source_locations_d, source_locations, n_shots * sizeof(int),
             cudaMemcpyHostToDevice);
  cudaMemcpy(receiver_locations_d, receiver_locations, n_shots * sizeof(int),
             cudaMemcpyHostToDevice);
  cudaMemset(wfc_d, 0, n_shots * ny * nx * sizeof(float));
  cudaMemset(wfp_d, 0, n_shots * ny * nx * sizeof(float));

  for (int t = 0; t < nt; ++t) {
    step<<<dimGrid, dimBlock>>>(wfc_d, wfp_d, v2dt2_d, one_over_dx2,
                                one_over_dy2, n_shots, ny, nx);
    add_sources<<<dimGrid_srcrcv, dimBlock_srcrcv>>>(
        wfp_d, source_amplitudes_d + t * n_shots, source_locations_d, n_shots);
    record_receivers<<<dimGrid_srcrcv, dimBlock_srcrcv>>>(
        receiver_amplitudes_d + t * n_shots, wfc_d, receiver_locations_d,
        n_shots);
    std::swap(wfp_d, wfc_d);
  }

  cudaMemcpy(receiver_amplitudes, receiver_amplitudes_d,
             n_shots * nt * sizeof(float), cudaMemcpyDeviceToHost);
  float ans = *std::max_element(receiver_amplitudes,
                                receiver_amplitudes + n_shots * nt);

  cudaFree(v2dt2_d);
  cudaFree(wfc_d);
  cudaFree(wfp_d);
  cudaFree(source_amplitudes_d);
  cudaFree(receiver_amplitudes_d);
  cudaFree(source_locations_d);
  cudaFree(receiver_locations_d);
  free(v2dt2);
  free(source_amplitudes);
  free(receiver_amplitudes);
  free(source_locations);
  free(receiver_locations);

  return ans;
}

}  // namespace

int main(void) {
  float ans = 0.0f;

  // warmup
  for (int i = 0; i < 3; ++i) {
    ans = run();
  }

  // timing
  clock_t t0 = clock();
  for (int i = 0; i < 10; ++i) {
    run();
  }
  clock_t t1 = clock();
  printf("CUDA: %f %g\n", (float)(t1 - t0) / CLOCKS_PER_SEC, ans);
}

To run:

!g++ -Ofast -march=native wave.cpp
!./a.out
!nvcc wave.cu
!./a.out
!python wave.py

Output:

C++: 2.859534 13.1366
CUDA: 0.272206 13.1366
eager cpu 70.94255131099999 tensor(13.1366)
eager gpu 4.180636458000009 tensor(13.1366, device='cuda:0')
compile cpu 19.22466220399997 tensor(13.1366)
compile gpu 1.8795379320000052 tensor(13.1366, device='cuda:0')
jitscript cpu 23.204399749000004 tensor(13.1366)
jitscript gpu 1.8472471919999975 tensor(13.1366, device='cuda:0')

The fastest PyTorch on the CPU is 19s (torch.compile), compared to 2.9s for C++.
The fastest PyTorch on the GPU is 1.8s (both torch.compile and jit.script), compared to 0.3s for CUDA.

Your profiling on the GPU is invalid since you are not synchronizing the code (unless I'm missing it somewhere). Besides that, you could try to use CUDA Graphs in PyTorch to reduce the dispatching and kernel launch overhead, as it seems these take the majority of the time while the actual workload is not dominating the timeline.
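For example, a minimal sketch of a synchronized timing wrapper (reusing the timing helper and closures from your code above) could look like this:

def run_and_sync(run_fn):
    # Wait for all queued GPU work to finish before timeit stops the clock,
    # so the measurement includes kernel execution rather than just launches.
    def wrapped():
        ans = run_fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return ans
    return wrapped

# e.g. timing('eager gpu', run_and_sync(closure(step, torch.device('cuda:0'))))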

Thank you for your reply and suggestion, @ptrblck.

I do not believe that explicit synchronization is necessary as the blocking calls (cudaMemcpy, cudaFree, etc.) will cause implicit synchronization. Nevertheless, I added a call to cudaDeviceSynchronize() at the end of the run function and re-ran the timing. As expected, the results were similar.
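For reference, a minimal sketch of that change (only the end of the CUDA run function is shown; everything else is unchanged):

  // ... at the end of run(), after copying the receiver data back:
  cudaDeviceSynchronize();  // explicit synchronization added for the re-timed run
  // cudaFree/free cleanup and return ans as before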

I modified the Python code to use graphs. I used two approaches.

The first (the new Run class) captures a graph of the entire loop over time steps in the constructor. The run method then only needs to replay the graph. Note that this method has a small unfair advantage over the C++/CUDA and other Python approaches, as it only allocates memory and performs setup when the object is created, rather than for every run.

The second (closure_step) only creates a graph that captures two time steps. The loop over time steps is thus only performed (nt-1)//2 times, each replaying this graph. In order for the inputs and outputs of the graph to be the same memory address each time, I copy the source amplitudes for those two time steps into a tensor source_amplitudes_two, which is what is used inside the graph, and do something similar for the recorded receiver data for those two time steps.
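In outline, both approaches follow the usual CUDA Graphs capture/replay pattern with static input and output tensors; here is a minimal, self-contained sketch of that pattern (the names and the toy computation are illustrative only):

import torch

def work(x, y):
    # Stand-in for the captured computation (illustrative only)
    y.copy_(2 * x)

static_in = torch.zeros(8, device='cuda')   # graph always reads this address
static_out = torch.zeros(8, device='cuda')  # graph always writes this address

# Warm up on a side stream before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        work(static_in, static_out)
torch.cuda.current_stream().wait_stream(s)

# Capture the kernels launched by work()
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    work(static_in, static_out)

# Replay: copy new data into the static input, replay, read the static output
static_in.copy_(torch.arange(8.0, device='cuda'))
g.replay()
print(static_out)

The full modified code: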

import math
import torch
from torch import Tensor
from timeit import timeit


def ricker(freq: float, length: int, dt: float, peak_time: float) -> Tensor:
    """Ricker wavelet for source."""
    t = torch.arange(length) * dt - peak_time
    a = math.pi**2 * freq**2 * t**2
    return (1 - 2 * a) * torch.exp(-a)


def step(wfc: Tensor, wfp: Tensor, v2dt2: Tensor,
         one_over_dx2: float, one_over_dy2: float):
    """One finite difference time step of wave equation.

    Args:
       wfc: current time step wavefield, [shot, x, y]
       wfp: previous time step wavefield
       v2dt2: velocity^2 * dt^2
       one_over_dx2/dy2: 1/dx^2
    """
    d2wdx2 = (
        -5 / 2 * wfc[:, 2:-2, 2:-2]
        + 4 / 3 * (wfc[:, 3:-1, 2:-2] + wfc[:, 1:-3, 2:-2])
        + -1 / 12 * (wfc[:, 4:, 2:-2] + wfc[:, :-4, 2:-2])
    ) * one_over_dx2
    d2wdy2 = (
        -5 / 2 * wfc[:, 2:-2, 2:-2]
        + 4 / 3 * (wfc[:, 2:-2, 3:-1] + wfc[:, 2:-2, 1:-3])
        + -1 / 12 * (wfc[:, 2:-2, 4:] + wfc[:, 2:-2, :-4])
    ) * one_over_dy2

    # wavefield at next time step
    wfp[:, 2:-2, 2:-2] = (
        v2dt2 * (d2wdx2 + d2wdy2)
        + 2 * wfc[:, 2:-2, 2:-2]
        - wfp[:, 2:-2, 2:-2]
    )
    return wfp, wfc


class Run():
    def __init__(self, step_fn, device,
                 n_shots: int = 10, ny: int = 200, nx: int = 200,
                 nt: int = 1000, dt: float = 0.0005,
                 dy: float = 5, dx: float = 5, freq: float = 25):

        # Setup input velocity, source, receiver
        v = 1500*torch.ones(ny, nx, device=device)
        self.v2dt2 = v[2:-2, 2:-2]**2 * dt**2
        self.wfc = torch.zeros(n_shots, ny, nx, device=device)
        self.wfp = torch.zeros(n_shots, ny, nx, device=device)
        self.one_over_dx2 = 1 / dx**2
        self.one_over_dy2 = 1 / dy**2
        self.nt = nt
        self.step_fn = step_fn
        self.source_amplitudes = (
            ricker(freq, nt, dt, 1.5/freq)
            .to(device)
            .reshape(-1, 1)
            .repeat(1, n_shots)
        )
        self.sources_batch = torch.arange(n_shots, dtype=torch.long,
                                          device=device)
        self.sources_y = 2 * torch.ones(n_shots, dtype=torch.long,
                                        device=device)
        self.sources_x = 2 * torch.ones(n_shots, dtype=torch.long,
                                        device=device)
        self.receivers_batch = torch.arange(n_shots, dtype=torch.long,
                                            device=device)
        self.receivers_y = 2 * torch.ones(n_shots, dtype=torch.long,
                                          device=device)
        self.receivers_x = 2 * torch.ones(n_shots, dtype=torch.long,
                                          device=device)
        self.receiver_amplitudes = torch.zeros(nt, n_shots, device=device)
        # Warmup
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                self.loop()
        torch.cuda.current_stream().wait_stream(s)

        # Capture graph
        self.g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.g):
            self.loop()

    def loop(self):
        self.wfc.fill_(0)
        self.wfp.fill_(0)
        wfc = self.wfc
        wfp = self.wfp
        # Loop over time steps
        for t in range(self.nt-1):
            wfc, wfp = self.step_fn(wfc, wfp, self.v2dt2,
                                    self.one_over_dx2, self.one_over_dy2)
            wfc[self.sources_batch,
                self.sources_y,
                self.sources_x] += self.source_amplitudes[t]
            self.receiver_amplitudes[t] = \
                wfp[self.receivers_batch, self.receivers_x, self.receivers_y]

    def run(self):
        self.g.replay()
        return self.receiver_amplitudes.max()


def closure_step(step_fn, device):
    def run(n_shots: int = 10, ny: int = 200, nx: int = 200,
            nt: int = 1000, dt: float = 0.0005,
            dy: float = 5, dx: float = 5, freq: float = 25):

        # Setup input velocity, source, receiver
        v = 1500*torch.ones(ny, nx, device=device)
        v2dt2 = v[2:-2, 2:-2]**2 * dt**2
        wfc = torch.zeros(n_shots, ny, nx, device=device)
        wfp = torch.zeros(n_shots, ny, nx, device=device)
        one_over_dx2 = 1 / dx**2
        one_over_dy2 = 1 / dy**2
        source_amplitudes = (
            ricker(freq, nt, dt, 1.5/freq)
            .to(device)
            .reshape(-1, 1)
            .repeat(1, n_shots)
        )
        sources_batch = torch.arange(n_shots, dtype=torch.long,
                                     device=device)
        sources_y = 2 * torch.ones(n_shots, dtype=torch.long,
                                   device=device)
        sources_x = 2 * torch.ones(n_shots, dtype=torch.long,
                                   device=device)
        receivers_batch = torch.arange(n_shots, dtype=torch.long,
                                       device=device)
        receivers_y = 2 * torch.ones(n_shots, dtype=torch.long, device=device)
        receivers_x = 2 * torch.ones(n_shots, dtype=torch.long, device=device)
        receiver_amplitudes = torch.zeros(nt, n_shots, device=device)

        source_amplitudes_two = torch.zeros(2, n_shots, device=device)
        receiver_amplitudes_two = torch.zeros(2, n_shots, device=device)

        def two_steps(wfc, wfp, source_amplitudes_two,
                      receiver_amplitudes_two):
            for i in range(2):
                wfc, wfp = step_fn(wfc, wfp, v2dt2, one_over_dx2, one_over_dy2)
                wfc[sources_batch,
                    sources_y,
                    sources_x] += source_amplitudes_two[i]
                receiver_amplitudes_two[i] = \
                    wfp[receivers_batch, receivers_x, receivers_y]

        # Warmup
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                two_steps(wfc, wfp, source_amplitudes_two, receiver_amplitudes_two)
        torch.cuda.current_stream().wait_stream(s)

        # Capture graph
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            two_steps(wfc, wfp, source_amplitudes_two, receiver_amplitudes_two)

        # Loop over graph
        wfc.fill_(0)
        wfp.fill_(0)
        for t in range(0, nt-1, 2):
            source_amplitudes_two[:] = source_amplitudes[t:t+2]
            g.replay()
            receiver_amplitudes[t:t+2] = receiver_amplitudes_two

        # Do something with output so compiler doesn't optimise away
        return receiver_amplitudes.max()
    return run


def timing(name, run_fn, n_warmup=3):
    ans = 0
    for _ in range(n_warmup):
        ans = run_fn()
    print(name, timeit(lambda: run_fn(), number=10), ans)


def eager():
    with torch.no_grad():
        timing('eager', Run(step, torch.device('cuda:0')).run)
        timing('eager step', closure_step(step, torch.device('cuda:0')))


def torchcompile():
    with torch.no_grad():
        timing('compile', Run(torch.compile(step, fullgraph=True),
                              torch.device('cuda:0')).run)
        timing('compile step',
               closure_step(torch.compile(step, fullgraph=True),
                            torch.device('cuda:0')))


def jitscript():
    with torch.no_grad():
        timing('jitscript', Run(torch.jit.script(step),
                                torch.device('cuda:0')).run)
        timing('jitscript step', closure_step(torch.jit.script(step),
                                              torch.device('cuda:0')))


eager()
torchcompile()
jitscript()

This time, the results were:

CUDA (original): 0.201396 13.1366
CUDA (extra synchronization): 0.199689 13.1366
eager 2.4772268719998465 tensor(13.1366, device='cuda:0')
eager step 1.6758502800003043 tensor(13.1366, device='cuda:0')
compile 0.592774092000127 tensor(13.1366, device='cuda:0')
compile step 0.9349224309999045 tensor(13.1366, device='cuda:0')
jitscript 0.707418725000025 tensor(13.1366, device='cuda:0')
jitscript step 1.0656032869997034 tensor(13.1366, device='cuda:0')

(Here, "step" refers to the second approach, where only two time steps are captured by the graph.)

While somewhat improved, there is unfortunately still a substantial (about 3X) performance difference between my hand-coded but not substantially optimized CUDA and the fastest PyTorch approach (graph of the full loop, using torch.compile on the step function).

To isolate the computation more clearly, here is a modification that just performs the time stepping (no sources or receivers). I have written three versions: one for torch.compile, one directly in Triton, and one in CUDA. Each is written in what I think is the most natural way for that approach, which results in small differences between them (the torch.compile version returns its result in a new Tensor, while the other two overwrite the previous wavefield, for example), but I hope you will agree that the results are still reasonable measures of the performance of each approach.

Timings:

  • torch.compile: 0.13s
  • Triton: 0.1s
  • CUDA: 0.08s

It seems that in this case torch.compile works pretty well. I will continue to explore. I know this is not a typical deep learning application, so it is understandable that it is not a priority for PyTorch, but I think there are many similar scientific and engineering applications that would benefit from being able to get good forward (and, in cases such as mine, backpropagation) performance using PyTorch.

torch.compile:

import torch
from torch import Tensor


def step(wfc: Tensor, wfp: Tensor, v2dt2: Tensor,
         one_over_dx2: float, one_over_dy2: float) -> Tensor:
    wfcp = torch.nn.functional.pad(wfc, (2, 2, 2, 2))
    return (
        v2dt2[None] * (
            (
                -5 / 2 * wfc
                + 4 / 3 * (wfcp[:, 3:-1, 2:-2] + wfcp[:, 1:-3, 2:-2])
                + -1 / 12 * (wfcp[:, 4:, 2:-2] + wfcp[:, :-4, 2:-2])
            ) * one_over_dx2 +
            (
                -5 / 2 * wfc
                + 4 / 3 * (wfcp[:, 2:-2, 3:-1] + wfcp[:, 2:-2, 1:-3])
                + -1 / 12 * (wfcp[:, 2:-2, 4:] + wfcp[:, 2:-2, :-4])
            ) * one_over_dy2
        )
        + 2 * wfc
        - wfp
    )


step = torch.compile(step)

n_shots = 50
ny = 500
nx = 500
nt = 100
torch.manual_seed(1)
v2dt2 = torch.rand(ny, nx, device=torch.device('cuda')) - 0.5
wfc = torch.rand(n_shots, ny, nx, device=torch.device('cuda')) - 0.5
wfp = torch.rand(n_shots, ny, nx, device=torch.device('cuda')) - 0.5
one_over_dy2 = one_over_dx2 = 1 / 5**2

for _ in range(nt):
    wfn = step(wfc, wfp, v2dt2, one_over_dx2, one_over_dy2)
    wfc, wfp = wfn, wfc

import time
torch.cuda.synchronize()
t0 = time.time()
for _ in range(nt):
    wfn = step(wfc, wfp, v2dt2, one_over_dx2, one_over_dy2)
    wfc, wfp = wfn, wfc
torch.cuda.synchronize()
t1 = time.time()
print(t1 - t0)

Triton:

import torch
import triton
import triton.language as tl


@triton.jit
def step(wfc_ptr, wfp_ptr, v2dt2_ptr, one_over_dx2, one_over_dy2,
         n_shots, ny, nx,
         BLOCKSIZE_Y: tl.constexpr, BLOCKSIZE_X: tl.constexpr):
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)
    pid_shot = tl.program_id(axis=2)
    y_start = pid_y * BLOCKSIZE_Y
    x_start = pid_x * BLOCKSIZE_X
    stride_shot = ny * nx
    offs_y = y_start + tl.arange(0, BLOCKSIZE_Y)
    offs_x = x_start + tl.arange(0, BLOCKSIZE_X)
    offs = pid_shot * stride_shot + offs_y[:, None] * nx + offs_x[None, :]
    wfc_c = tl.load(wfc_ptr + offs,
                    mask=(offs_y[:, None] < ny) &
                         (offs_x[None, :] < nx),
                    other=0.0)
    wfc_up = tl.load(wfc_ptr + offs - nx,
                     mask=(0 <= offs_y[:, None] - 1) &
                          (offs_y[:, None] - 1 < ny) &
                          (offs_x[None, :] < nx),
                     other=0.0)
    wfc_up2 = tl.load(wfc_ptr + offs - 2*nx,
                      mask=(0 <= offs_y[:, None] - 2) &
                           (offs_y[:, None] - 2 < ny) &
                           (offs_x[None, :] < nx),
                      other=0.0)
    wfc_down = tl.load(wfc_ptr + offs + nx,
                       mask=(offs_y[:, None] + 1 < ny) &
                            (offs_x[None, :] < nx),
                       other=0.0)
    wfc_down2 = tl.load(wfc_ptr + offs + 2*nx,
                        mask=(offs_y[:, None] + 2 < ny) &
                             (offs_x[None, :] < nx),
                        other=0.0)
    wfc_left = tl.load(wfc_ptr + offs - 1,
                       mask=(offs_y[:, None] < ny) &
                            (0 <= offs_x[None, :] - 1) &
                            (offs_x[None, :] - 1 < nx),
                       other=0.0)
    wfc_left2 = tl.load(wfc_ptr + offs - 2,
                        mask=(offs_y[:, None] < ny) &
                             (0 <= offs_x[None, :] - 2) &
                             (offs_x[None, :] - 2 < nx),
                        other=0.0)
    wfc_right = tl.load(wfc_ptr + offs + 1,
                        mask=(offs_y[:, None] < ny) &
                             (offs_x[None, :] + 1 < nx),
                        other=0.0)
    wfc_right2 = tl.load(wfc_ptr + offs + 2,
                         mask=(offs_y[:, None] < ny) &
                              (offs_x[None, :] + 2 < nx),
                         other=0.0)

    wfp = tl.load(wfp_ptr + offs,
                  mask=(offs_y[:, None] < ny) &
                       (offs_x[None, :] < nx),
                  other=0.0)
    v2dt2 = tl.load(v2dt2_ptr + offs_y[:, None] * nx + offs_x[None, :],
                    mask=(offs_y[:, None] < ny) &
                         (offs_x[None, :] < nx),
                    other=0.0)

    tl.store(wfp_ptr + offs,
             v2dt2 * (
                 (
                     -5 / 2 * wfc_c
                     + 4 / 3 * (wfc_up + wfc_down)
                     + -1 / 12 * (wfc_up2 + wfc_down2)
                 ) * one_over_dx2 +
                 (
                     -5 / 2 * wfc_c
                     + 4 / 3 * (wfc_left + wfc_right)
                     + -1 / 12 * (wfc_left2 + wfc_right2)
                 ) * one_over_dy2
             )
             + 2 * wfc_c
             - wfp,
             mask=(offs_y[:, None] < ny) & (offs_x[None, :] < nx))


n_shots = 50
ny = 500
nx = 500
nt = 100
torch.manual_seed(1)
v2dt2 = torch.rand(ny, nx, device=torch.device('cuda')) - 0.5
wfc = torch.rand(n_shots, ny, nx, device=torch.device('cuda')) - 0.5
wfp = torch.rand(n_shots, ny, nx, device=torch.device('cuda')) - 0.5
one_over_dy2 = one_over_dx2 = 1 / 5**2

grid = (triton.cdiv(nx, 32), triton.cdiv(ny, 32), n_shots)

for _ in range(nt):
    step[grid](wfc, wfp, v2dt2, one_over_dx2, one_over_dy2, n_shots,
               ny, nx, BLOCKSIZE_Y=32, BLOCKSIZE_X=32)
    wfc, wfp = wfp, wfc

import time
torch.cuda.synchronize()
t0 = time.time()
for _ in range(nt):
    step[grid](wfc, wfp, v2dt2, one_over_dx2, one_over_dy2, n_shots,
               ny, nx, BLOCKSIZE_Y=32, BLOCKSIZE_X=32)
    wfc, wfp = wfp, wfc
torch.cuda.synchronize()
t1 = time.time()
print(t1 - t0)

CUDA:

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

namespace {

__global__ void step(float const *__restrict const wfc,
                     float *__restrict const wfp,
                     float const *__restrict const v2dt2,
                     float const one_over_dx2, float const one_over_dy2,
                     int const n_shots, int const ny, int const nx) {
  int x = blockIdx.x * blockDim.x + threadIdx.x + 2;
  int y = blockIdx.y * blockDim.y + threadIdx.y + 2;
  int shot = blockIdx.z * blockDim.z + threadIdx.z;
  int i = shot * ny * nx + y * nx + x;
  if (y < ny - 2 && x < nx - 2) {
    wfp[i] =
        v2dt2[y * nx + x] *
            ((-5.0f / 2.0f * wfc[i] +
              4.0f / 3.0f * (wfc[i + nx] + wfc[i - nx]) +
              -1.0f / 12.0f * (wfc[i + 2 * nx] + wfc[i - 2 * nx])) *
                 one_over_dx2 +
             (-5.0f / 2.0f * wfc[i] + 4.0f / 3.0f * (wfc[i + 1] + wfc[i - 1]) +
              -1.0f / 12.0f * (wfc[i + 2] + wfc[i - 2])) *
                 one_over_dy2) +
        2 * wfc[i] - wfp[i];
  }
}

}  // namespace

int main() {
  int n_shots = 50;
  int ny = 504;
  int nx = 504;
  int nt = 100;
  float *v2dt2 = (float *)calloc(ny * nx, sizeof(float));
  float *v2dt2_d;
  float *wfc = (float *)calloc(n_shots * ny * nx, sizeof(float));
  float *wfc_d;
  float *wfp = (float *)calloc(n_shots * ny * nx, sizeof(float));
  float *wfp_d;
  float one_over_dx2 = 1.0f / 25.0f;
  float one_over_dy2 = 1.0f / 25.0f;
  dim3 dimBlock(32, 32, 1);
  dim3 dimGrid((nx - 4 + 31) / 32, (ny - 4 + 31) / 32, n_shots);

  srand(1);
  for (int y = 2; y < ny - 2; ++y) {
    for (int x = 2; x < nx - 2; ++x) {
      v2dt2[y * nx + x] = (float)rand() / RAND_MAX - 0.5f;
    }
  }
  for (int shot = 0; shot < n_shots; ++shot) {
    for (int y = 2; y < ny - 2; ++y) {
      for (int x = 2; x < nx - 2; ++x) {
        wfc[shot * ny * nx + y * nx + x] = (float)rand() / RAND_MAX - 0.5f;
      }
    }
  }
  for (int shot = 0; shot < n_shots; ++shot) {
    for (int y = 2; y < ny - 2; ++y) {
      for (int x = 2; x < nx - 2; ++x) {
        wfp[shot * ny * nx + y * nx + x] = (float)rand() / RAND_MAX - 0.5f;
      }
    }
  }

  cudaMalloc(&v2dt2_d, ny * nx * sizeof(float));
  cudaMalloc(&wfc_d, n_shots * ny * nx * sizeof(float));
  cudaMalloc(&wfp_d, n_shots * ny * nx * sizeof(float));

  cudaMemcpy(v2dt2_d, v2dt2, ny * nx * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(wfc_d, wfc, n_shots * ny * nx * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(wfp_d, wfp, n_shots * ny * nx * sizeof(float),
             cudaMemcpyHostToDevice);

  cudaDeviceSynchronize();
  clock_t t0 = clock();
  for (int t = 0; t < nt; ++t) {
    step<<<dimGrid, dimBlock>>>(wfc_d, wfp_d, v2dt2_d, one_over_dx2,
                                one_over_dy2, n_shots, ny, nx);
    float *tmp = wfc_d;
    wfc_d = wfp_d;
    wfp_d = tmp;
  }
  cudaDeviceSynchronize();
  clock_t t1 = clock();
  printf("CUDA: %f\n", (float)(t1 - t0) / CLOCKS_PER_SEC);

  cudaFree(v2dt2_d);
  cudaFree(wfc_d);
  cudaFree(wfp_d);
  free(v2dt2);
  free(wfc);
  free(wfp);
}

As already mentioned, PyTorch will show its dispatching and Autograd overheads for tiny workloads that cannot saturate the device. CUDA Graphs helps in these cases when you are launching a lot of tiny kernels, but you might not see a huge difference if the actual workload itself is small.
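If you want to confirm where the time goes, the PyTorch profiler will show CPU-side (dispatch/launch) time next to CUDA kernel time. A minimal sketch, assuming the eager step function and the problem sizes from the first listing:

import torch
from torch.profiler import profile, ProfilerActivity

device = torch.device('cuda')
wfc = torch.zeros(10, 200, 200, device=device)
wfp = torch.zeros(10, 200, 200, device=device)
v2dt2 = torch.full((196, 196), 1500.0**2 * 0.0005**2, device=device)

with torch.no_grad(), profile(activities=[ProfilerActivity.CPU,
                                          ProfilerActivity.CUDA]) as prof:
    for _ in range(100):
        wfn = step(wfc, wfp, v2dt2, 1 / 25, 1 / 25)
        wfp, wfc = wfc, wfn

# Compare "Self CPU" (dispatch/launch overhead) with "Self CUDA" (kernel time)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))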