How do I export THNN C functions to PyTorch using FFI?

I have experimented a little with the PyTorch tutorial on writing C extensions. The tutorial is very helpful: I added a THSize function to it, and that worked too.

While starting to learn C and C++, I want to experiment with the C code of TH and THNN without having to wait too long. After trying this tutorial, I realized that I could perhaps put every TH and THNN function into the template of the C extension tutorial and export each C function to PyTorch. If that works, I can add lots of printf calls inside the C functions and experiment with them from the Python side.

So I started experimenting:

  1. Functions that start like `int THSize_isSameSizeAs(const int64_t *sizeA, int64_t dimsA, const int64_t *sizeB, int64_t dimsB)` seem to work fine; at least there is no error:
/* src/my_lib.c */
#include <TH/TH.h>


/*
 * Element-wise sum of two float tensors: output = input1 + input2.
 *
 * output is resized to input1's shape before the addition is written
 * into it. Returns 1 on success, or 0 (leaving output untouched) when
 * input1 and input2 do not have the same size.
 */
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
                       THFloatTensor *output)
{
    int same_shape = THFloatTensor_isSameSizeAs(input1, input2);

    if (same_shape) {
        THFloatTensor_resizeAs(output, input1);
        /* output = input1 + 1.0 * input2 */
        THFloatTensor_cadd(output, input1, 1.0, input2);
    }
    return same_shape ? 1 : 0;
}

/*
 * Backward pass of my_lib_add_forward with respect to one of its inputs.
 *
 * For output = input1 + input2, d(output)/d(input1) is the identity, so
 * the incoming gradient passes through unchanged: grad_input = grad_output.
 * grad_input is resized to grad_output's shape first. Always returns 1.
 */
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
{
    THFloatTensor_resizeAs(grad_input, grad_output);
    /* Copy rather than fill with 1: a constant 1 is only the right answer
     * when grad_output itself is all ones (e.g. loss = output.sum()). */
    THFloatTensor_copy(grad_input, grad_output);
    return 1;
}


/*
 * Compare two size vectors for equality.
 *
 * sizeA points to dimsA extents and sizeB to dimsB extents. Returns 1 when
 * both have the same number of dimensions and identical extents, else 0.
 * Either pointer may be NULL when its dimension count is 0 (it is never
 * dereferenced in that case).
 *
 * NOTE(review): TH itself appears to export a function with this exact
 * name; defining it again in an extension may clash at link time -- confirm.
 */
int THSize_isSameSizeAs(const int64_t *sizeA, int64_t dimsA, const int64_t *sizeB, int64_t dimsB) {
  if (dimsA != dimsB) {
    return 0;
  }
  /* Index with int64_t to match dimsA: an `int` counter would overflow
   * (undefined behavior) before reaching a bound above INT_MAX. */
  for (int64_t d = 0; d < dimsA; ++d) {
    if (sizeA[d] != sizeB[d]) {
      return 0;
    }
  }
  return 1;
}
  2. However, when I tried to use the two functions from THNN/generic/BatchNormalization.c, I got an error while building. Below are my_lib.c and my_lib.h.
    How can I make this work? Thanks a lot!
//my_lib.c

#include <TH/TH.h>
#include <THNN/THNN.h>


// #ifndef TH_GENERIC_FILE
// #define TH_GENERIC_FILE "generic/BatchNormalization.c"
// #else

/*
 * Batch-normalization forward pass, written as a TH "generic" template.
 *
 * NOTE(review): THNN_(), THTensor_(), real and accreal are type-generic
 * macros; presumably they only resolve when this file is compiled through
 * the TH_GENERIC_FILE include machinery (the TH_GENERIC_FILE guard near the
 * includes is commented out) -- confirm before building it stand-alone.
 *
 * The feature/channel dimension is dim 1 of input: each loop iteration
 * normalizes the slice newSelect(input, 1, f) and uses entry f of the 1-d
 * per-feature tensors.
 *
 *   state         unused in this body.
 *   input         tensor to normalize.
 *   output        resized to input's shape; receives
 *                 ((x - mean) * invstd) * w + b per feature.
 *   weight, bias  may be NULL, in which case w = 1 and b = 0.
 *   running_mean, running_var
 *                 train: updated in place as
 *                   momentum * batch_stat + (1 - momentum) * running_stat
 *                 (running_var uses the unbiased variance sum / (n - 1)).
 *                 eval: used as the normalization statistics.
 *   save_mean, save_std
 *                 written only in train mode; save_std stores the INVERSE
 *                 standard deviation 1 / sqrt(var + eps), despite its name.
 *   train         use batch statistics (true) or running statistics (false).
 *   momentum      weight given to the batch statistic in the running update.
 *   eps           added to the variance before the square root.
 */
void THNN_(BatchNormalization_updateOutput)(
  THNNState *state, THTensor *input, THTensor *output,
  THTensor *weight, THTensor *bias,
  THTensor *running_mean, THTensor *running_var,
  THTensor *save_mean, THTensor *save_std,
  bool train, double momentum, double eps)
{
  THTensor_(resizeAs)(output, input);
  long nInput = THTensor_(size)(input, 1);
  long f;
  // number of elements in each feature plane
  ptrdiff_t n = THTensor_(nElement)(input) / nInput;

  // One iteration per feature; each iteration writes only feature f's slice
  // of output and the f-th entries of the save_* / running_* tensors.
  #pragma omp parallel for
  for (f = 0; f < nInput; ++f) {
    THTensor *in = THTensor_(newSelect)(input, 1, f);
    THTensor *out = THTensor_(newSelect)(output, 1, f);

    real mean, invstd;

    if (train) {
      // compute mean per input
      accreal sum = 0;
      TH_TENSOR_APPLY(real, in, sum += *in_data;);

      mean = (real) sum / n;
      THTensor_(set1d)(save_mean, f, (real) mean);

      // compute variance per input
      // (sum of squared deviations from the batch mean, divided by n below)
      sum = 0;
      TH_TENSOR_APPLY(real, in,
        sum += (*in_data - mean) * (*in_data - mean););

      // A constant plane with eps == 0 would give 1/0; use 0 instead.
      if (sum == 0 && eps == 0.0) {
        invstd = 0;
      } else {
        invstd = (real) (1 / sqrt(sum/n + eps));
      }
      THTensor_(set1d)(save_std, f, (real) invstd);

      // update running averages
      THTensor_(set1d)(running_mean, f,
        (real) (momentum * mean + (1 - momentum) * THTensor_(get1d)(running_mean, f)));

      // Unbiased (n - 1) variance for the running estimate.
      // NOTE(review): n == 1 makes this a division by zero -- confirm callers
      // guarantee more than one value per feature in training mode.
      accreal unbiased_var = sum / (n - 1);
      THTensor_(set1d)(running_var, f,
        (real) (momentum * unbiased_var + (1 - momentum) * THTensor_(get1d)(running_var, f)));
    } else {
      // evaluation mode: normalize with the running statistics
      mean = THTensor_(get1d)(running_mean, f);
      invstd = 1 / sqrt(THTensor_(get1d)(running_var, f) + eps);
    }

    // compute output
    real w = weight ? THTensor_(get1d)(weight, f) : 1;
    real b = bias ? THTensor_(get1d)(bias, f) : 0;

    TH_TENSOR_APPLY2(real, in, real, out,
      *out_data = (real) (((*in_data - mean) * invstd) * w + b););

    THTensor_(free)(out);
    THTensor_(free)(in);
  }
}

/*
 * Batch-normalization backward pass, written as a TH "generic" template.
 *
 * NOTE(review): THNN_(), THTensor_(), real and accreal are type-generic
 * macros; presumably they only resolve when this file is compiled through
 * the TH_GENERIC_FILE include machinery (the TH_GENERIC_FILE guard near the
 * includes is commented out) -- confirm before building it stand-alone.
 *
 * The feature/channel dimension is dim 1 of input: each loop iteration
 * works on the slices newSelect(..., 1, f) and on entry f of the 1-d
 * per-feature tensors.
 *
 *   state         unused in this body.
 *   input         forward-pass input.
 *   gradOutput    gradient w.r.t. the forward output; must have the same
 *                 shape as input (THNN_CHECK_SHAPE).
 *   gradInput     may be NULL; otherwise resized to input's shape and
 *                 overwritten with the gradient w.r.t. input.
 *   gradWeight    may be NULL; entry f is incremented (accumulated, not
 *                 overwritten) by scale * dot(x - mean, gradOut) * invstd.
 *   gradBias      may be NULL; entry f is incremented by scale * sum(gradOut).
 *   weight        may be NULL, in which case every w is 1.
 *   running_mean, running_var
 *                 statistics used when train is false
 *                 (invstd = 1 / sqrt(running_var + eps)).
 *   save_mean, save_std
 *                 statistics used when train is true; save_std is read
 *                 directly as the inverse standard deviation.
 *   train         use saved batch statistics (true) or running ones (false).
 *   scale         factor applied to the gradWeight / gradBias updates.
 *   eps           only used in evaluation mode in this function.
 */
void THNN_(BatchNormalization_backward)(
  THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput,
  THTensor *gradWeight, THTensor *gradBias, THTensor *weight,
  THTensor *running_mean, THTensor *running_var,
  THTensor *save_mean, THTensor *save_std,
  bool train, double scale, double eps)
{
  THNN_CHECK_SHAPE(input, gradOutput);
  long nInput = THTensor_(size)(input, 1);
  long f;
  // number of elements in each feature plane
  ptrdiff_t n = THTensor_(nElement)(input) / nInput;

  // Resize gradInput exactly once, before the parallel region. resizeAs
  // mutates the shared gradInput tensor; calling it from inside the omp loop
  // lets one thread mutate it while other threads are already slicing it and
  // writing through those slices (a data race).
  if (gradInput) {
    THTensor_(resizeAs)(gradInput, input);
  }

  // One iteration per feature; each iteration writes only feature f's slice
  // of gradInput and the f-th entries of gradWeight / gradBias.
  #pragma omp parallel for
  for (f = 0; f < nInput; ++f) {
    THTensor *in = THTensor_(newSelect)(input, 1, f);
    THTensor *gradOut = THTensor_(newSelect)(gradOutput, 1, f);
    real w = weight ? THTensor_(get1d)(weight, f) : 1;
    real mean, invstd;
    if (train) {
      mean = THTensor_(get1d)(save_mean, f);
      invstd = THTensor_(get1d)(save_std, f);
    } else {
      mean = THTensor_(get1d)(running_mean, f);
      invstd = 1 / sqrt(THTensor_(get1d)(running_var, f) + eps);
    }

    // sum over all gradOutput in feature plane
    accreal sum = 0;
    TH_TENSOR_APPLY(real, gradOut, sum += *gradOut_data;);

    // dot product of the Q(X) and gradOutput
    accreal dotp = 0;
    TH_TENSOR_APPLY2(real, in, real, gradOut,
      dotp += (*in_data - mean) * (*gradOut_data););

    if (gradInput) {
      THTensor *gradIn = THTensor_(newSelect)(gradInput, 1, f);

      if (train) {
        // when in training mode
        // Q(X) = X - E[x] ; i.e. input centered to zero mean
        // Y = Q(X) / σ    ; i.e. BN output before weight and bias
        // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / σ * w

        // projection of gradOutput on to output scaled by std
        real k = (real) dotp * invstd * invstd / n;
        TH_TENSOR_APPLY2(real, gradIn, real, in,
          *gradIn_data = (*in_data - mean) * k;);

        accreal gradMean = sum / n;
        TH_TENSOR_APPLY2(real, gradIn, real, gradOut,
          *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * invstd * w;);

      } else {
        // when in evaluation mode
        // Q(X) = X - running_mean  ; i.e. input centered to zero mean
        // Y = Q(X) / running_std    ; i.e. BN output before weight and bias
        // dL/dX = w / running_std
        TH_TENSOR_APPLY2(real, gradIn, real, gradOut,
          *gradIn_data = *gradOut_data * invstd * w;);
      }

      THTensor_(free)(gradIn);
    }

    if (gradWeight) {
      real val = THTensor_(get1d)(gradWeight, f);
      THTensor_(set1d)(gradWeight, f, val + scale * dotp * invstd);
    }

    if (gradBias) {
      real val = THTensor_(get1d)(gradBias, f);
      THTensor_(set1d)(gradBias, f, val + scale * sum);
    }

    THTensor_(free)(gradOut);
    THTensor_(free)(in);
  }
}

// my_lib.h
//
// NOTE(review): the traceback that follows this listing comes from cffi's C
// parser (pycparser) rejecting these declarations. The reported position,
// column 13, appears to match the `*` in a parameter line such as
// "  THNNState *state", i.e. cffi does not know THNNState as a type name.
// cffi's cdef parser does not run the C preprocessor, so the THNN_() name
// macro and generic names such as THNNState / THTensor are unavailable here
// unless declared in the cdef text itself. An FFI header presumably needs
// concrete, fully expanded declarations (e.g. a float instantiation taking
// THFloatTensor *, like my_lib_add_forward) -- confirm against the cffi docs.

// Forward pass; defined in my_lib.c.
void THNN_(BatchNormalization_updateOutput)(
  THNNState *state, THTensor *input, THTensor *output,
  THTensor *weight, THTensor *bias,
  THTensor *running_mean, THTensor *running_var,
  THTensor *save_mean, THTensor *save_std,
  bool train, double momentum, double eps);

// Backward pass; defined in my_lib.c.
void THNN_(BatchNormalization_backward)(
  THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput,
  THTensor *gradWeight, THTensor *gradBias, THTensor *weight,
  THTensor *running_mean, THTensor *running_var,
  THTensor *save_mean, THTensor *save_std,
  bool train, double scale, double eps);

The error is the following:

Traceback (most recent call last):
  File "/Users/Natsume/miniconda2/envs/pytorch3.6/lib/python3.6/site-packages/cffi/cparser.py", line 269, in _parse
    ast = _get_parser().parse(csource)
  File "/Users/Natsume/miniconda2/envs/pytorch3.6/lib/python3.6/site-packages/pycparser/c_parser.py", line 152, in parse
    debug=debuglevel)
  File "/Users/Natsume/miniconda2/envs/pytorch3.6/lib/python3.6/site-packages/pycparser/ply/yacc.py", line 331, in parse
    return self.parseopt_notrack(input, lexer, debug, tracking, tokenfunc)
  File "/Users/Natsume/miniconda2/envs/pytorch3.6/lib/python3.6/site-packages/pycparser/ply/yacc.py", line 1199, in parseopt_notrack
    tok = call_errorfunc(self.errorfunc, errtoken, self)
  File "/Users/Natsume/miniconda2/envs/pytorch3.6/lib/python3.6/site-packages/pycparser/ply/yacc.py", line 193, in call_errorfunc
    r = errorfunc(token)
  File "/Users/Natsume/miniconda2/envs/pytorch3.6/lib/python3.6/site-packages/pycparser/c_parser.py", line 1761, in p_error
    column=self.clex.find_tok_column(p)))
  File "/Users/Natsume/miniconda2/envs/pytorch3.6/lib/python3.6/site-packages/pycparser/plyparser.py", line 66, in _parse_error
    raise ParseError("%s: %s" % (coord, msg))
pycparser.plyparser.ParseError: :37:13: before: *

During handling of the above exception, another exception occurred: