I am trying to check the performance gain of using a c++ extension instead of just python, when implementing custom conv2d operations.
When benchmarking it for a simple conv2d implementation, I although get worse performance with the c++ extension (83µs vs 56µs), despite virtually identical code. What could cause this ?
Here’s my python code:
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import time
import conv2d_cpp
def time_fn(fn, n=10000):
t0 = time.time()
for _ in range(n):
fn()
t = time.time()-t0
torch.cuda.synchronize()
print(f"{t/n*10**6:.2f}µs per loop")
def conv_2d(x: Tensor, kernel: Tensor, bias: Tensor) -> Tensor:
"""get each individual conv window to compute the whole conv in one matmul"""
assert(x.shape[0] == kernel.shape[0] and kernel.shape[1] == bias.shape[0])
output_shape = (x.shape[1] - kernel.shape[2]) + 1
im2col_matrix = F.unfold(x, kernel.shape[2])
im2col_conv = torch.matmul(kernel.moveaxis(0, 1).flatten(start_dim=1), im2col_matrix) + bias
im2col_conv = im2col_conv.reshape(kernel.shape[1], output_shape, output_shape)
return im2col_conv
def main():
device = 'cuda'
input = torch.randint(-255, 255, size=(1, 256, 256)).float().to(device)
kernel = torch.tensor([[[[8., 9., 7.], [4., 4., 2.], [-1., 6., -3.]]]]).to(device)
bias = torch.tensor([[0.06]]).to(device)
di, do = 16, 32
input, kernel, bias = input.repeat(di, 1, 1), kernel.repeat(di, do, 1, 1), bias.repeat(do, 1)
time_fn(lambda: conv2d_cpp.forward(input[None], kernel, bias))
time_fn(lambda: conv_2d(input, kernel, bias))
and my cpp code:
#include <torch/extension.h>
namespace F = torch::nn::functional;
torch::Tensor conv2d_forward(
torch::Tensor im, // should be of shape (bs, di, side, side)
torch::Tensor kernel, // (di, do, ks, ks)
torch::Tensor bias
) {
auto ks = kernel.size(3);
auto stride = ks/2;
auto side = im.size(2);
auto output_shape = side - ks + 1;
auto im2col_matrix = F::unfold(im, F::UnfoldFuncOptions({ks, ks}).padding(0).stride(stride));
auto im2col_conv = torch::matmul(kernel.moveaxis(0, 1).flatten(1), im2col_matrix) + bias;
return im2col_conv.reshape({kernel.size(1), output_shape, output_shape});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "Implementation of a forward pass of conv2d in C++";
m.def("forward", &conv2d_forward, "Conv2d forward");
}