Incompatible function arguments with C++ extension

I'm calling a function written as a C++ extension, but I get the following error:

incompatible function arguments. The following argument types are supported:
    1. (arg0: int, arg1: int, arg2: int, arg3: float, arg4: int, arg5: at::Tensor, arg6: at::Tensor, arg7: at::Tensor) -> int

Invoked with: 1, 1024, 512, 0.23, 48, tensor([[[ 0.6824,  0.1034,  0.4344],
         [-0.5448, -0.3518, -0.3540],
         [-0.2593, -0.3775,  0.7748],
         ...,
         [ 0.0176, -0.4164,  0.7543],
         [ 0.2757,  0.1917,  0.4559],
         [ 0.4798,  0.3335,  0.5477]]], device='cuda:0'), tensor([[[ 0.6824,  0.1034,  0.4344],
         [ 0.4867,  0.0260, -0.3418],
         [ 0.1544, -0.1811,  0.7433],
         ...,
         [-0.3573, -0.0785,  0.5690],
         [-0.3930, -0.4164, -0.0296],
         [ 0.0472,  0.1214,  0.6922]]], device='cuda:0'), tensor([[   0,  357,  548,  211,  799,  842,  961,  639,   88,  701,   44,  242,
          309,  971,  388,  219,  327,  849,  528,  130,  164,  555,  694,  345,
          858,  731,  746,  889,  222,  888,  875,  905,  507,  788,  529,  151,
          679,  199,  661,  866,  666,  458,  778,  972,  945,   95,  743,  699,
          411,  672,  217,  105,  516,  465,  286,  761,  344,  791,  485,  851,
          932,  104,  241,  863,  269,  149,  489,  206,  333,  415,  595,  596,
          179,  137,  195,    4,  687,   77,  967,  614,  160,  426,  273,  370,
          207,  444,  441,  649,  925,  892,  969,  659,  899,   82,  418,   20,
          921,  102,  225,  953,  202,  974,  700,  324,  540,  318,  587,  615,
          936,  112,  414,  522,  514,  311,  322,  808,  856,  845,  637,  409,
          451,  245,  626,  554,  303,   49,  994,  810,  789,   30,  397,  532,
           43,  234,  376,   57,  954,  430,  784,  157,  848,  816,  302,  267,
          174,  682,  178,  798,  200,  135,  638,  917,  513, 1000,  521,  271,
           51,  757,  644,  138,  348,  907,  751,  243,  774,  190,  655,  139,
          482,  350,  161,   39,  860,  904,  926,  819,  738,  113,  400,  188,
          753,  633,  713,  975,  780,  226,  563,  117,  479,   99,  192,   26,
          279,  603,   11,  685,   10,  552,  320,  871,  606,  154,   18,  567,
          281,  212,  325,  624,  452,  796,  152,   59,  933,  634,  748,  240,
          903,  868,  537,  591,  627,  978,  837,  464,  470,  836,  391,  825,
          914,  406,   41,  237,  275,   23,   45,  148,  502,  163,   90,  146,
         1002,  551,  432,  287,  878,  902,  256,  448,  337,  159,  367,  833,
           52,  949,  812,  114,  423,  776,  501,  227, 1021,  523,  960,  956,
          959,  312,  768,   79,  814,  997,  873,  846,  913,  496,  840,  326,
          741,  463,   53,  640,  723,  277,  622,  771,  966,  447,   65,  189,
          306,   98,  307,  472, 1016,  475,  775,  238,  420,  220,  354,  821,
          511,  332,  707,  165,  657, 1018,  150,  923,  124,  918,  115,  210,
          436,  992,   46,  895,  498,  946,  182,  719,  407,   61,  632,   17,
          664,  284,  460,  171,  175,  886,  565,  876,  197,  296,  754,  728,
          853,  843,  254,  143,  468,  571,  583,  991,  795,  942,  449,  525,
          805,  980,  239,  544,  562,  720,   33,  983,   13,  712,  869,  696,
          253,  770,  266,  909,   47,  471,  144,  377,  262,  316,  968,   70,
          556,  535,  515,  937,  601,  387,  721,  569,  539,  366,  362,  204,
          928, 1005,  368, 1009,  820,  930, 1023,   76,  870,  499,  520,  509,
          568,  665,  512,  382,  480,  605,  702,  578,  301,  172,  635,  584,
          957,  289,  223,   74,  131,  349, 1006,  106,  169,  435,  847,  125,
          725,  958,  735,  759,   78,   25,  984,  906,  744,  371,  116,  393,
          688,  727,  205,  487, 1017,  897,  246,  575,  572,  752,   58,  467,
          298,  885,  469,  920,  293,  428,  590,  979,  793,  802,  887,  492,
          695,  359,   55,  103,  260,   62,  681,  558,  964,    7,  559,  173,
          317,  841,  686,  500,  803,  128,  553,  844,  763,  291,  910,  618,
          608,  101,  804,  602,  790,  809,  410,   64,  405,  648,  209,  714,
          132,  380,  228,  924,  651,  299,  419,  645, 1012,  986,   66,  939,
          546,  249,  993,  834,  709,  880,   37,  669]], device='cuda:0',
       dtype=torch.int32), tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]], device='cuda:0', dtype=torch.int32)

I thought I passed the right argument types. Is the difference between at::Tensor and torch::Tensor the cause of this error?
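For what it's worth, torch::Tensor and at::Tensor are the same type (the torch namespace re-exports ATen's Tensor), so the name difference alone should not trigger this error. Counting the arguments in the log points somewhere else: the supported signature lists 8 parameters (arg0 to arg7), while the call passes 9 values, namely 5 scalars followed by 4 tensors. The sketch below makes the same check in pure Python with `inspect.signature`. `extension_fn` and `explain_call` are hypothetical helpers: `extension_fn` is a stand-in with the same parameter list as the bound function, and plain strings stand in for the tensors so the snippet runs without PyTorch.

```python
import inspect


# Hypothetical stand-in with the same 8 parameters pybind11 lists for the
# extension function (arg0..arg7); the body is only a placeholder.
def extension_fn(arg0: int, arg1: int, arg2: int, arg3: float, arg4: int,
                 arg5, arg6, arg7) -> int:
    return 0


def explain_call(fn, *args):
    """Return None if args bind to fn's signature, else a short diagnosis."""
    sig = inspect.signature(fn)
    try:
        sig.bind(*args)
    except TypeError as exc:
        return (f"passed {len(args)} positional args, "
                f"signature has {len(sig.parameters)}: {exc}")
    return None


# Same values as in the log: 5 scalars, then 4 tensors
# (the strings "t0".."t3" are placeholders for the tensors).
print(explain_call(extension_fn, 1, 1024, 512, 0.23, 48, "t0", "t1", "t2", "t3"))
```

With only three placeholder tensors, `explain_call` returns `None`. The C++ code isn't shown here, so it isn't clear which tensor is extra. The fix is either to drop it from the Python call or to add the missing parameter to the C++ function and rebuild.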

I'm getting the same issue. Following up on this.

OK, I think I know what's going on. You need to include the following header in the file where you wrap your function with pybind11:

#include <torch/extension.h>

I'm having the same issue. Has anyone been able to find a solution?
#include <torch/extension.h>
I already have this header included, but I still get the same error.
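If the header is already included and the error persists, you can narrow it down by comparing the call against the signature pybind11 prints, one argument at a time. The sketch below is a rough pure-Python version of that check. It assumes the signature string is copied verbatim from the error message. `expected_types`, `compare`, and the `PY_NAMES_FOR` mapping are hypothetical helpers, and `Tensor` is a stand-in for `torch.Tensor` so the snippet runs without PyTorch.

```python
import re

# Signature copied verbatim from the pybind11 error message above.
SIGNATURE = ("(arg0: int, arg1: int, arg2: int, arg3: float, arg4: int, "
             "arg5: at::Tensor, arg6: at::Tensor, arg7: at::Tensor) -> int")

# Rough, assumed mapping from C++ type names in the signature to Python class
# names that pybind11 would accept (it normally converts an int to a float).
PY_NAMES_FOR = {
    "int": {"int"},
    "float": {"float", "int"},
    "at::Tensor": {"Tensor"},
}


def expected_types(signature):
    """Pull the declared C++ types out of the 'argN: type' entries."""
    return re.findall(r"arg\d+: ([^,)]+)", signature)


def compare(signature, *args):
    """Return human-readable mismatches between declared and passed args."""
    declared = expected_types(signature)
    problems = []
    if len(args) != len(declared):
        problems.append(f"count: {len(args)} passed, {len(declared)} declared")
    for i, (cpp_type, value) in enumerate(zip(declared, args)):
        py_name = type(value).__name__
        allowed = PY_NAMES_FOR.get(cpp_type)  # unknown C++ types are skipped
        if allowed is not None and py_name not in allowed:
            problems.append(f"arg{i}: expected {cpp_type}, got {py_name}")
    return problems


class Tensor:
    """Stand-in for torch.Tensor so this sketch runs without PyTorch."""


t = Tensor()
for problem in compare(SIGNATURE, 1, 1024, 512, 0.23, 48, t, t, t, t):
    print(problem)
```

In a real session you would pass the actual tensors instead of the stand-in. Tensor subclasses such as `torch.nn.Parameter` have a different class name, so this name-based check would flag them even though pybind11 may accept them.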