nvprof shows PyTorch using a CUDA kernel with fp32 accumulation, turing_fp16_s1688cudnn_fp16_256x128_ldg8_relu_f2f_exp_small_nhwc_tn_v1, for the following code.
Python:
import torch
model = torch.nn.Sequential(
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
x = torch.rand((200, 128, 8, 8), dtype=torch.half, device=device)
model = model.to(device).half()
print(model.forward(x))
C++ (slightly different):
#include <iostream>
#include <chrono>
#include <torch/torch.h>
using namespace torch;
const int filters = 128;
struct Block : nn::Module {
    nn::Conv2d conv1 = nullptr;
    nn::Conv2d conv2 = nullptr;

    Block() {
        conv1 = register_module("conv1", nn::Conv2d(nn::Conv2dOptions(filters, filters, 3).stride(1).padding(1)));
        conv2 = register_module("conv2", nn::Conv2d(nn::Conv2dOptions(filters, filters, 3).stride(1).padding(1)));
    }

    Tensor forward(Tensor x) {
        at::Tensor residual(x.clone());
        x = conv1->forward(x);
        x = relu(x);
        x = conv2->forward(x);
        x += residual;
        x = relu(x);
        return x;
    }
};
int main() {
    double sendSeconds = 0;
    torch::Device device = torch::kCUDA;
    nn::Sequential trunk = nn::Sequential();
    for (int i = 0; i < 10; i++) {
        trunk->push_back(Block());
    }
    NoGradGuard guard;
    trunk->eval();
    trunk->to(device, torch::kHalf);
    for (int i = 0; i < 500; i++) {
        Tensor tt = torch::rand({1, filters, 8, 8}).pin_memory();
        //auto mid1 = std::chrono::steady_clock::now();
        tt = tt.to(device, torch::kHalf);
        //auto mid2 = std::chrono::steady_clock::now();
        tt = trunk->forward(tt);
        if (i == 499) {
            std::cout << tt << std::endl;
        }
    }
}
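For context, the only precision-related switches I have found on the C++ side are the global context flags, and as far as I can tell none of them controls the accumulation type of cuDNN convolutions. This is just a sketch of what I mean (whether each setter exists depends on the LibTorch version):

#include <torch/torch.h>

int main() {
    // Global knobs I am aware of (version-dependent); none of these seems to
    // request fp16 accumulation for cuDNN convolutions.
    at::globalContext().setBenchmarkCuDNN(true);             // cuDNN autotuning only
    at::globalContext().setAllowTF32CuDNN(false);            // TF32 on Ampere, not fp16 accumulation
    at::globalContext().setAllowFP16ReductionCuBLAS(true);   // cuBLAS GEMMs, not cuDNN convolutions
}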
I want to NOT use fp32 accumulation; an example of the kind of kernel I'm after is turing_h1688cudnn_128x128_ldg8_relu_exp_small_nhwc_tn_v1 (although with different dimensions).
How can I turn off fp32 accumulation and run the convolutions straight through in fp16? I would like a solution for C++.
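My understanding from the cuDNN documentation is that the accumulation precision is chosen by the compute type on the convolution descriptor, so a fallback might be to call cuDNN directly for these convolutions instead of going through LibTorch. Below is a minimal, hypothetical sketch of the descriptor setup only (no algorithm selection, workspace allocation, or actual forward call). The key part is passing CUDNN_DATA_HALF as the compute type, which as I read it is what selects the h1688-style fp16-accumulation kernels, whereas CUDNN_DATA_FLOAT gives the fp32-accumulating f2f kernels I currently see:

#include <cudnn.h>

int main() {
    cudnnHandle_t handle;
    cudnnCreate(&handle);

    cudnnTensorDescriptor_t xDesc, yDesc;
    cudnnFilterDescriptor_t wDesc;
    cudnnConvolutionDescriptor_t convDesc;
    cudnnCreateTensorDescriptor(&xDesc);
    cudnnCreateTensorDescriptor(&yDesc);
    cudnnCreateFilterDescriptor(&wDesc);
    cudnnCreateConvolutionDescriptor(&convDesc);

    // Half-precision NHWC tensors matching the C++ shapes above (batch 1, 128 channels, 8x8).
    cudnnSetTensor4dDescriptor(xDesc, CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, 1, 128, 8, 8);
    cudnnSetTensor4dDescriptor(yDesc, CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, 1, 128, 8, 8);
    cudnnSetFilter4dDescriptor(wDesc, CUDNN_DATA_HALF, CUDNN_TENSOR_NHWC, 128, 128, 3, 3);

    // The compute type (last argument) is the accumulation precision:
    // CUDNN_DATA_HALF -> fp16 accumulation, CUDNN_DATA_FLOAT -> fp32 accumulation.
    cudnnSetConvolution2dDescriptor(convDesc, /*pad_h=*/1, /*pad_w=*/1,
                                    /*stride_h=*/1, /*stride_w=*/1,
                                    /*dilation_h=*/1, /*dilation_w=*/1,
                                    CUDNN_CROSS_CORRELATION, CUDNN_DATA_HALF);
    cudnnSetConvolutionMathType(convDesc, CUDNN_TENSOR_OP_MATH);  // allow tensor cores

    // ... pick an algorithm, allocate buffers, call cudnnConvolutionForward ...

    cudnnDestroyConvolutionDescriptor(convDesc);
    cudnnDestroyFilterDescriptor(wDesc);
    cudnnDestroyTensorDescriptor(yDesc);
    cudnnDestroyTensorDescriptor(xDesc);
    cudnnDestroy(handle);
}

I would much rather keep everything in LibTorch, though, so if there is a way to get the same effect through the C++ API (or a setting PyTorch already exposes that I've missed), that would be ideal.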