Usage of __restrict__ in CUDA

Dear PyTorch Team,

Thank you very much for the library, which is very pleasant to use and has great documentation!

In the CUDA extension example, you use const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> where one would otherwise pass a raw scalar_t* __restrict__ pointer without the accessor. This seems like a very nice solution that I would also like to use in some of my own CUDA code. However, if I read TensorAccessor.h correctly, the __restrict__-qualified pointer is a member of a class, and according to this discussion on the NVIDIA forums, __restrict__ on class members is ignored by the compiler. To check this, I put together the small example below in Compiler Explorer; the generated assembly suggests that __restrict__ is indeed ignored in that case. Could you maybe comment on this, as most likely I am missing something here?
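For context, here is a minimal sketch of the two signature styles I am comparing. The accessor form follows, as far as I can tell, the pattern of the C++/CUDA extension tutorial; the kernel names, the element type, and the 3-dimensional shape are just illustrative:

#include <torch/extension.h>

// Accessor-based signature, as in the extension tutorial:
template <typename scalar_t>
__global__ void kernel_accessor(
    const torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> output) {
  // elements are accessed as input[i][j][k]; the __restrict__ sits inside the accessor
}

// Raw-pointer signature one would otherwise write by hand:
template <typename scalar_t>
__global__ void kernel_raw(const scalar_t* __restrict__ input,
                           scalar_t* __restrict__ output) {
  // elements are accessed via manual index arithmetic; __restrict__ is on the parameter itself
}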

Thank you
Lukas

Compiler Explorer code (this looks much nicer on the Compiler Explorer website)

// Some code copied from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/TensorAccessor.h

// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor
// is used to enable the __restrict__ keyword/modifier for the data
// passed to cuda.
template <typename T>
struct DefaultPtrTraits {
  typedef T* PtrType;
};

template <typename T>
struct RestrictPtrTraits {
  typedef T* __restrict__ PtrType;
};

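// Mimics TensorAccessor: the __restrict__-qualified pointer is stored as a class member.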
class wrapper{
public:
    RestrictPtrTraits<double>::PtrType ptr;
};

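// Baseline: plain pointers, so the compiler must assume input and output may alias.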
__global__ void assign(DefaultPtrTraits<double>::PtrType input, DefaultPtrTraits<double>::PtrType output) {
    int tid = blockIdx.x;
    {
        double la = input[tid];
        output[tid] = la;
        la = input[tid];
        output[tid] = la;
    }
}

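// __restrict__ applied directly to the kernel parameters.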
__global__ void assign_restrict_direct(RestrictPtrTraits<double>::PtrType input, RestrictPtrTraits<double>::PtrType output) {
    int tid = blockIdx.x;
    {
        double la = input[tid];
        output[tid] = la;
        la = input[tid];
        output[tid] = la;
    }
}


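// __restrict__ pointer hidden inside a class member, analogous to the accessor case.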
__global__ void assign_restrict_struct(wrapper input, wrapper output) {
    int tid = blockIdx.x;
    {
        double la = input.ptr[tid];
        output.ptr[tid] = la;
        la = input.ptr[tid];
        output.ptr[tid] = la;
    }
}

Compiler Explorer Assembly (NVCC 11.5.0 sm_52)

Note that _Z22assign_restrict_struct7wrapperS_ still issues two loads and two stores, just like the plain _Z6assignPdS_ kernel, whereas _Z22assign_restrict_directPdS_ gets away with a single ld.global.nc load and one store because the compiler may assume the pointers do not alias. The struct-wrapped version therefore does not seem to exploit __restrict__.

.visible .entry _Z6assignPdS_(
        .param .u64 _Z6assignPdS__param_0,
        .param .u64 _Z6assignPdS__param_1
)
{

        ld.param.u64    %rd1, [_Z6assignPdS__param_0];
        ld.param.u64    %rd2, [_Z6assignPdS__param_1];
        cvta.to.global.u64      %rd3, %rd2;
        cvta.to.global.u64      %rd4, %rd1;
        mov.u32         %r1, %ctaid.x;
        mul.wide.s32    %rd5, %r1, 8;
        add.s64         %rd6, %rd4, %rd5;
        ld.global.f64   %fd1, [%rd6];
        add.s64         %rd7, %rd3, %rd5;
        st.global.f64   [%rd7], %fd1;
        ld.global.f64   %fd2, [%rd6];
        st.global.f64   [%rd7], %fd2;
        ret;

}
.visible .entry _Z22assign_restrict_directPdS_(
        .param .u64 _Z22assign_restrict_directPdS__param_0,
        .param .u64 _Z22assign_restrict_directPdS__param_1
)
{

        ld.param.u64    %rd1, [_Z22assign_restrict_directPdS__param_0];
        ld.param.u64    %rd2, [_Z22assign_restrict_directPdS__param_1];
        cvta.to.global.u64      %rd3, %rd2;
        cvta.to.global.u64      %rd4, %rd1;
        mov.u32         %r1, %ctaid.x;
        mul.wide.s32    %rd5, %r1, 8;
        add.s64         %rd6, %rd4, %rd5;
        ld.global.nc.f64        %fd1, [%rd6];
        add.s64         %rd7, %rd3, %rd5;
        st.global.f64   [%rd7], %fd1;
        ret;

}
.visible .entry _Z22assign_restrict_struct7wrapperS_(
        .param .align 8 .b8 _Z22assign_restrict_struct7wrapperS__param_0[8],
        .param .align 8 .b8 _Z22assign_restrict_struct7wrapperS__param_1[8]
)
{

        ld.param.u64    %rd1, [_Z22assign_restrict_struct7wrapperS__param_1];
        ld.param.u64    %rd2, [_Z22assign_restrict_struct7wrapperS__param_0];
        cvta.to.global.u64      %rd3, %rd2;
        cvta.to.global.u64      %rd4, %rd1;
        mov.u32         %r1, %ctaid.x;
        mul.wide.s32    %rd5, %r1, 8;
        add.s64         %rd6, %rd3, %rd5;
        ld.global.f64   %fd1, [%rd6];
        add.s64         %rd7, %rd4, %rd5;
        st.global.f64   [%rd7], %fd1;
        ld.global.f64   %fd2, [%rd6];
        st.global.f64   [%rd7], %fd2;
        ret;

}