CUDA error: an illegal memory access was encountered (pytorch CUDA extension) on GPU and Segmentation fault on CPU

HI, I am trying to call the native functions. However, I met ‘CUDA error: an illegal memory access was encountered’ when I ran the CUDA version and it gave ‘Segmentation fault’ when I switched to the CPU version.

I tested the code on pytorch 1.5/1.6/1.7 with cuda 9.2, all pytorch version gave the same error.

What I was trying to do:
First wrap the functions:

#include <torch/extension.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>

std::tuple<at::Tensor, at::Tensor, at::Tensor> layer_norm_forward_cpu(
    const at::Tensor & input,
    const at::Tensor & weight,
    const at::Tensor & bias,
    int64_t M, int64_t N, double eps) {
    return at::native::layer_norm_cpu(input, weight, bias, M, N, eps);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> backward_layer_norm_cpu(
    const at::Tensor & grad_out,
    const at::Tensor & input,
    const at::Tensor & mean,
    const at::Tensor & rstd,
    const at::Tensor & weight,
    int64_t M, int64_t N, std::array<bool,3> output_mask) {
    return at::native::layer_norm_backward_cpu(grad_out, input, mean, rstd, weight, M, N, output_mask);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> layer_norm_forward_cuda(
    const at::Tensor & input,
    const at::Tensor & weight,
    const at::Tensor & bias,
    int64_t M, int64_t N, double eps) {
    return at::native::layer_norm_cuda(input, weight, bias, M, N, eps);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> backward_layer_norm_cuda(
    const at::Tensor & grad_out,
    const at::Tensor & input,
    const at::Tensor & mean,
    const at::Tensor & rstd,
    const at::Tensor & weight,
    int64_t M, int64_t N, std::array<bool,3> output_mask) {
    return at::native::layer_norm_backward_cuda(grad_out, input, mean, rstd, weight, M, N, output_mask);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("layer_norm_forward_cpu",  &layer_norm_forward_cpu, "layer norm forward (cpu version)");
    m.def("layer_norm_backward_cpu", &backward_layer_norm_cpu, "layer norm backward (cpu version)");
    m.def("layer_norm_forward_cuda", &layer_norm_forward_cuda, "layer norm forward (cuda version)");
    m.def("layer_norm_backward_cuda",&backward_layer_norm_cuda, "layer norm backward (cuda version)");
}

and then call the functions in

import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

import native

class layer_norm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, normalized_shape, weight, bias, eps, training):
        N = 1
        if isinstance(normalized_shape, int):
            N = normalized_shape
        elif isinstance(normalized_shape, (list, tuple)):
            for i in normalized_shape:
                N *= i
        else:
            raise RuntimeError("unexpected type of normalized_shape".format(type(normalized_shape)))
        M = x.nelement() // N

        if x.is_cuda:

            y, mean, rstd = native.layer_norm_forward_cuda(x, weight, bias, M, N, eps)
        else:
            y, mean, rstd = native.layer_norm_forward_cpu(x, weight, bias, M, N, eps)

        if training:
            ctx.layer_norm_input = x
            ctx.layer_norm_parameters = (mean, rstd, weight, M, N)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.layer_norm_input
        mean, rstd, weight, M, N = ctx.layer_norm_parameters

        output_mask = [True, True, True]
        if grad_output.is_cuda:
            grad_input, grad_weight, grad_bias = native.layer_norm_backward_cuda(grad_output, x, mean, rstd, weight, M, N, output_mask)
        else:
            grad_input, grad_weight, grad_bias = native.layer_norm_backward_cpu(grad_output, x, mean, rstd, weight, M, N, output_mask)
        ctx.layer_norm_input = None
        ctx.layer_norm_parameters = None
        return grad_input, None, grad_weight, grad_bias, None, None, None, None

class LayerNorm(nn.LayerNorm):
    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True):
        nn.LayerNorm.__init__(self, normalized_shape, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        y = layer_norm.apply(x, self.normalized_shape, self.weight, self.bias, self.eps, self.training)
        return y


if __name__ == "__main__":
    seed = 2809
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic=True #https://github.com/pytorch/pytorch/issues/8019

    model = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
            LayerNorm([64,56,56])
            )
    print(model)

    #model = model.cuda()
    model.train()
    iteration = 10
    for i in range(iteration):
        print("index: ", i)
        x = torch.rand(512,64,56,56)
        x = x - 0.5
        #x = x.cuda()

        y = model(x)
        z = y.sum()
        z.backward()

I also uploaded all the code to GitHub - irving-qin/nativefunctions
Just run bash install.sh for testing.

The forward process of the wrapped layer norm seemed to be normal. However, it throwed errors in the backward function.

Thanks so much if anyone could give me some tips.

Your current code cannot be build using a recent source build and fails with:

/workspace/src/nativefunctions/native.cpp:39:66: error: could not convert ‘mean’ from ‘const at::Tensor’ to ‘c10::IntArrayRef’ {aka ‘c10::ArrayRef<long int>’}
   39 |     return at::native::layer_norm_backward_cuda(grad_out, input, mean, rstd, weight, M, N, output_mask);
      |                                                                  ^~~~
      |                                                                  |
      |                                                                  const at::Tensor

EDIT: it seems the IntArrayRef normalized_shape input is missing as seen here.

Hi, @ptrblck Thank so much for your attention on the question.

It seems the api was changed since pytorch 1.8.

In my eary trival, I didn’t build the code again the source code of the pytorch. Instead I build it with a pre-installed version.

I found the function definition in path something like
~/.pyenv/versions/3.6.5/lib/python3.6/site-packages/torch/include/ATen/NativeFunctions.h
The path prefix might be different on different machines.

If I ran
grep layer_norm ~/.pyenv/versions/3.6.5/lib/python3.6/site-packages/torch/include/ATen/NativeFunctions.h on pytorch1.5/1.6/1.7 it gave

CAFFE2_API Tensor layer_norm(const Tensor & input, IntArrayRef normalized_shape, const Tensor & weight={}, const Tensor & bias={}, double eps=1e-05, bool cudnn_enable=true);
CAFFE2_API std::tuple<Tensor,Tensor,Tensor> layer_norm_cpu(const Tensor & input, const Tensor & weight, const Tensor & bias, int64_t M, int64_t N, double eps);
CAFFE2_API std::tuple<Tensor,Tensor,Tensor> layer_norm_cuda(const Tensor & input, const Tensor & weight, const Tensor & bias, int64_t M, int64_t N, double eps);
CAFFE2_API std::tuple<Tensor,Tensor,Tensor> layer_norm_backward_cpu(const Tensor & grad_out, const Tensor & input, const Tensor & mean, const Tensor & rstd, const Tensor & weight, int64_t M, int64_t N, std::array<bool,3> output_mask);
CAFFE2_API std::tuple<Tensor,Tensor,Tensor> layer_norm_backward_cuda(const Tensor & grad_out, const Tensor & input, const Tensor & mean, const Tensor & rstd, const Tensor & weight, int64_t M, int64_t N, std::array<bool,3> output_mask);

I just tried the grep on pytorch 1.8, it gave

TORCH_API Tensor layer_norm(const Tensor & input, IntArrayRef normalized_shape, const Tensor & weight={}, const Tensor & bias={}, double eps=1e-05, bool cudnn_enable=true);
TORCH_API std::tuple<Tensor,Tensor,Tensor> layer_norm_cpu(const Tensor & input, IntArrayRef normalized_shape, const Tensor & weight, const Tensor & bias, double eps);
TORCH_API std::tuple<Tensor,Tensor,Tensor> layer_norm_cuda(const Tensor & input, IntArrayRef normalized_shape, const Tensor & weight, const Tensor & bias, double eps);
TORCH_API std::tuple<Tensor,Tensor,Tensor> math_native_layer_norm(const Tensor & input, IntArrayRef normalized_shape, const Tensor & weight, const Tensor & bias, double eps);
TORCH_API std::tuple<Tensor,Tensor,Tensor> layer_norm_backward_cpu(const Tensor & grad_out, const Tensor & input, IntArrayRef normalized_shape, const Tensor & mean, const Tensor & rstd, const Tensor & weight, const Tensor & bias, std::array<bool,3> output_mask);
TORCH_API std::tuple<Tensor,Tensor,Tensor> layer_norm_backward_cuda(const Tensor & grad_out, const Tensor & input, IntArrayRef normalized_shape, const Tensor & mean, const Tensor & rstd, const Tensor & weight, const Tensor & bias, std::array<bool,3> output_mask);

I have updated the code on github repo (run git pull, please ) to make it support pytorch 1.8. However the error was still there.

Hi, I found out the reason for the issue. It is caused by the grad_output to be not contiguous. After adding grad_output = grad_output.contiguous(), the error is gone.