So I came up with this code:
at::Tensor fwht_forward(
at::Tensor input
) {
auto n_samples = input.size(0);
auto n_features = input.size(1);
auto G = n_features / 2;
auto M = 2;
at::Tensor temp = at::zeros({n_samples, G, 2});
for (auto i = 0; i < n_samples; i++) {
for (auto j = 0; j < n_features; j = j + 2) {
temp[i][j/2][0] = input[i][j] + input[i][j+1];
temp[i][j/2][1] = input[i][j] - input[i][j+1];
}
}
at::Tensor res = at::zeros({n_samples, G, 2});
res.copy_(temp);
for (auto i = 2; i < std::log2(n_features) + 1; i++) {
temp = at::zeros({n_samples, G / 2, M * 2});
auto res_acc = res.accessor<float, 3>();
for (auto j = 0; j < res.size(0); j++) {
for (auto k = 0; k < G; k = k + 2) {
for (auto l = 0; l < res.size(2); l = l + 4) {
temp[j][k/2][l] = res_acc[j][k][l/2] + res_acc[j][k+1][l/2];
temp[j][k/2][l+1] = res_acc[j][k][l/2] - res_acc[j][k+1][l/2];
temp[j][k/2][l+2] = res_acc[j][k][l/2+1] - res_acc[j][k+1][l/2+1];
temp[j][k/2][l+3] = res_acc[j][k][l/2+1] + res_acc[j][k+1][l/2+1];
}
}
}
res.copy_(temp);
G = G / 2;
M = M * 2;
}
auto temp_acc = temp.accessor<float, 3>();
at::Tensor output = at::zeros({n_samples, n_features});
for (auto m = 0; m < n_samples; m++) {
for (auto n = 0; n < temp.sizes()[2]; n++) {
output[m][n] = temp_acc[m][0][n];
}
}
return output * pow(std::sqrt(n_features), -1);
}
That compiles without errors. But now when I import it and try to use like
import torch
import fwht
fwht.forward(torch.randn(10, 1024))
I get the following error:
RuntimeError: copy_from does not support automatic differentiation; use copy_ instead (_s_copy_from at torch/csrc/autograd/generated/VariableType.cpp:459)
frame #0: at::CPUFloatType::s_copy_(at::Tensor&, at::Tensor const&, bool) const + 0x23b (0x7f009f07d3cb in /home/iacolippo/anaconda3/lib/python3.6/site-packages/torch/lib/libcaffe2.so)
frame #1: at::Type::copy_(at::Tensor&, at::Tensor const&, bool) const + 0x61 (0x7f009f1c5981 in /home/iacolippo/anaconda3/lib/python3.6/site-packages/torch/lib/libcaffe2.so)
frame #2: fwht_forward(at::Tensor) + 0x300 (0x7f009c4beed0 in /home/iacolippo/anaconda3/lib/python3.6/site-packages/fwht-0.0.0-py3.6-linux-x86_64.egg/fwht.cpython-36m-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x85c0 (0x7f009c4c15c0 in /home/iacolippo/anaconda3/lib/python3.6/site-packages/fwht-0.0.0-py3.6-linux-x86_64.egg/fwht.cpython-36m-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x100a5 (0x7f009c4c90a5 in /home/iacolippo/anaconda3/lib/python3.6/site-packages/fwht-0.0.0-py3.6-linux-x86_64.egg/fwht.cpython-36m-x86_64-linux-gnu.so)
frame #5: _PyCFunction_FastCallDict + 0x154 (0x5602be0d0364 in /home/iacolippo/anaconda3/bin/python3.6)
frame #6: <unknown function> + 0x19eebc (0x5602be162ebc in /home/iacolippo/anaconda3/bin/python3.6)
frame #7: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #8: PyEval_EvalCodeEx + 0x329 (0x5602be15d8d9 in /home/iacolippo/anaconda3/bin/python3.6)
frame #9: PyEval_EvalCode + 0x1c (0x5602be15e67c in /home/iacolippo/anaconda3/bin/python3.6)
frame #10: <unknown function> + 0x1bdf2e (0x5602be181f2e in /home/iacolippo/anaconda3/bin/python3.6)
frame #11: _PyCFunction_FastCallDict + 0x91 (0x5602be0d02a1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #12: <unknown function> + 0x19eebc (0x5602be162ebc in /home/iacolippo/anaconda3/bin/python3.6)
frame #13: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #14: <unknown function> + 0x197f24 (0x5602be15bf24 in /home/iacolippo/anaconda3/bin/python3.6)
frame #15: <unknown function> + 0x198dc1 (0x5602be15cdc1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #16: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #17: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #18: <unknown function> + 0x197f24 (0x5602be15bf24 in /home/iacolippo/anaconda3/bin/python3.6)
frame #19: <unknown function> + 0x198dc1 (0x5602be15cdc1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #20: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #21: _PyEval_EvalFrameDefault + 0x10c7 (0x5602be1853e7 in /home/iacolippo/anaconda3/bin/python3.6)
frame #22: <unknown function> + 0x19828e (0x5602be15c28e in /home/iacolippo/anaconda3/bin/python3.6)
frame #23: <unknown function> + 0x198dc1 (0x5602be15cdc1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #24: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #25: _PyEval_EvalFrameDefault + 0x10c7 (0x5602be1853e7 in /home/iacolippo/anaconda3/bin/python3.6)
frame #26: <unknown function> + 0x198b8b (0x5602be15cb8b in /home/iacolippo/anaconda3/bin/python3.6)
frame #27: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #28: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #29: <unknown function> + 0x198b8b (0x5602be15cb8b in /home/iacolippo/anaconda3/bin/python3.6)
frame #30: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #31: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #32: <unknown function> + 0x197f24 (0x5602be15bf24 in /home/iacolippo/anaconda3/bin/python3.6)
frame #33: <unknown function> + 0x198dc1 (0x5602be15cdc1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #34: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #35: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #36: <unknown function> + 0x198b8b (0x5602be15cb8b in /home/iacolippo/anaconda3/bin/python3.6)
frame #37: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #38: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #39: <unknown function> + 0x197f24 (0x5602be15bf24 in /home/iacolippo/anaconda3/bin/python3.6)
frame #40: <unknown function> + 0x198dc1 (0x5602be15cdc1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #41: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #42: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #43: PyEval_EvalCodeEx + 0x329 (0x5602be15d8d9 in /home/iacolippo/anaconda3/bin/python3.6)
frame #44: PyEval_EvalCode + 0x1c (0x5602be15e67c in /home/iacolippo/anaconda3/bin/python3.6)
frame #45: <unknown function> + 0x214ce4 (0x5602be1d8ce4 in /home/iacolippo/anaconda3/bin/python3.6)
frame #46: PyRun_FileExFlags + 0xa1 (0x5602be1d90e1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #47: PyRun_SimpleFileExFlags + 0x1c4 (0x5602be1d92e4 in /home/iacolippo/anaconda3/bin/python3.6)
frame #48: Py_Main + 0x5ff (0x5602be1dcdaf in /home/iacolippo/anaconda3/bin/python3.6)
frame #49: main + 0xee (0x5602be0a38be in /home/iacolippo/anaconda3/bin/python3.6)
frame #50: __libc_start_main + 0xf1 (0x7f00bfae91c1 in /lib/x86_64-linux-gnu/libc.so.6)
frame #51: <unknown function> + 0x1c70da (0x5602be18b0da in /home/iacolippo/anaconda3/bin/python3.6)
So the culprit seems res.copy_(temp);, but I don’t understand how it should be modified. Any ideas?