Advanced indexing when writing an extension in ATen

Hi all,

I’m trying to write a function that computes the Fast Walsh Hadamard transform using ATen, at some point I have a few lines that make use of advanced indexing. I’ve only found this Github issue regarding advanced indexing in ATen: https://github.com/zdevito/ATen/issues/78

I have zero experience in this language and https://pytorch.org/cppdocs/ is not of much help for the moment.

My Python code looks like this:

temp = torch.zeros((N_samples, N // 2, 2), device=x.device)  # very important, have to
# initialize the new tensors on the used device
temp[:, :, 0] = x[:, 0::2] + x[:, 1::2]
temp[:, :, 1] = x[:, 0::2] - x[:, 1::2]
res = torch.tensor(temp, device=x.device)
# Second and further stage
for nStage in range(2, int(log(N, 2)) + 1):
    temp = torch.zeros((N_samples, G // 2, M * 2), device=x.device)
    temp[:, 0:G // 2, 0:M * 2:4] = res[:, 0:G:2, 0:M:2] + res[:, 1:G:2, 0:M:2]
    temp[:, 0:G // 2, 1:M * 2:4] = res[:, 0:G:2, 0:M:2] - res[:, 1:G:2, 0:M:2]
    temp[:, 0:G // 2, 2:M * 2:4] = res[:, 0:G:2, 1:M:2] - res[:, 1:G:2, 1:M:2]
    temp[:, 0:G // 2, 3:M * 2:4] = res[:, 0:G:2, 1:M:2] + res[:, 1:G:2, 1:M:2]
    res = torch.tensor(temp, device=x.device)
    G = G // 2
    M = M * 2

res = temp[:, 0, :]

How do I handle this kind of indexing in ATen?

Thanks in advance for your help

So I came up with this code:

at::Tensor fwht_forward(
    at::Tensor input
    ) {
    auto n_samples = input.size(0);
    auto n_features = input.size(1);
    auto G = n_features / 2;
    auto M = 2;

    at::Tensor temp = at::zeros({n_samples, G, 2});
    for (auto i = 0; i < n_samples; i++) {
        for (auto j = 0; j < n_features; j = j + 2) {
            temp[i][j/2][0] = input[i][j] + input[i][j+1];
            temp[i][j/2][1] = input[i][j] - input[i][j+1];
        }
    }
    at::Tensor res = at::zeros({n_samples, G, 2});
    res.copy_(temp);

    for (auto i = 2; i < std::log2(n_features) + 1; i++) {
        temp = at::zeros({n_samples, G / 2, M * 2});
        auto res_acc = res.accessor<float, 3>();
        for (auto j = 0; j < res.size(0); j++) {
            for (auto k = 0; k < G; k = k + 2) {
                for (auto l = 0; l < res.size(2); l = l + 4) {
                    temp[j][k/2][l] = res_acc[j][k][l/2] + res_acc[j][k+1][l/2];
                    temp[j][k/2][l+1] = res_acc[j][k][l/2] - res_acc[j][k+1][l/2];
                    temp[j][k/2][l+2] = res_acc[j][k][l/2+1] - res_acc[j][k+1][l/2+1];
                    temp[j][k/2][l+3] = res_acc[j][k][l/2+1] + res_acc[j][k+1][l/2+1];
                }
            }
        }
        res.copy_(temp);
        G = G / 2;
        M = M * 2;
    }

    auto temp_acc = temp.accessor<float, 3>();
    at::Tensor output = at::zeros({n_samples, n_features});
    for (auto m = 0; m < n_samples; m++) {
        for (auto n = 0; n < temp.sizes()[2]; n++) {
            output[m][n] = temp_acc[m][0][n];
        }
    }
    return output * pow(std::sqrt(n_features), -1);
}

That compiles without errors. But now when I import it and try to use like

import torch
import fwht

fwht.forward(torch.randn(10, 1024))

I get the following error:

RuntimeError: copy_from does not support automatic differentiation; use copy_ instead (_s_copy_from at torch/csrc/autograd/generated/VariableType.cpp:459)
frame #0: at::CPUFloatType::s_copy_(at::Tensor&, at::Tensor const&, bool) const + 0x23b (0x7f009f07d3cb in /home/iacolippo/anaconda3/lib/python3.6/site-packages/torch/lib/libcaffe2.so)
frame #1: at::Type::copy_(at::Tensor&, at::Tensor const&, bool) const + 0x61 (0x7f009f1c5981 in /home/iacolippo/anaconda3/lib/python3.6/site-packages/torch/lib/libcaffe2.so)
frame #2: fwht_forward(at::Tensor) + 0x300 (0x7f009c4beed0 in /home/iacolippo/anaconda3/lib/python3.6/site-packages/fwht-0.0.0-py3.6-linux-x86_64.egg/fwht.cpython-36m-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x85c0 (0x7f009c4c15c0 in /home/iacolippo/anaconda3/lib/python3.6/site-packages/fwht-0.0.0-py3.6-linux-x86_64.egg/fwht.cpython-36m-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x100a5 (0x7f009c4c90a5 in /home/iacolippo/anaconda3/lib/python3.6/site-packages/fwht-0.0.0-py3.6-linux-x86_64.egg/fwht.cpython-36m-x86_64-linux-gnu.so)
frame #5: _PyCFunction_FastCallDict + 0x154 (0x5602be0d0364 in /home/iacolippo/anaconda3/bin/python3.6)
frame #6: <unknown function> + 0x19eebc (0x5602be162ebc in /home/iacolippo/anaconda3/bin/python3.6)
frame #7: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #8: PyEval_EvalCodeEx + 0x329 (0x5602be15d8d9 in /home/iacolippo/anaconda3/bin/python3.6)
frame #9: PyEval_EvalCode + 0x1c (0x5602be15e67c in /home/iacolippo/anaconda3/bin/python3.6)
frame #10: <unknown function> + 0x1bdf2e (0x5602be181f2e in /home/iacolippo/anaconda3/bin/python3.6)
frame #11: _PyCFunction_FastCallDict + 0x91 (0x5602be0d02a1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #12: <unknown function> + 0x19eebc (0x5602be162ebc in /home/iacolippo/anaconda3/bin/python3.6)
frame #13: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #14: <unknown function> + 0x197f24 (0x5602be15bf24 in /home/iacolippo/anaconda3/bin/python3.6)
frame #15: <unknown function> + 0x198dc1 (0x5602be15cdc1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #16: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #17: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #18: <unknown function> + 0x197f24 (0x5602be15bf24 in /home/iacolippo/anaconda3/bin/python3.6)
frame #19: <unknown function> + 0x198dc1 (0x5602be15cdc1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #20: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #21: _PyEval_EvalFrameDefault + 0x10c7 (0x5602be1853e7 in /home/iacolippo/anaconda3/bin/python3.6)
frame #22: <unknown function> + 0x19828e (0x5602be15c28e in /home/iacolippo/anaconda3/bin/python3.6)
frame #23: <unknown function> + 0x198dc1 (0x5602be15cdc1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #24: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #25: _PyEval_EvalFrameDefault + 0x10c7 (0x5602be1853e7 in /home/iacolippo/anaconda3/bin/python3.6)
frame #26: <unknown function> + 0x198b8b (0x5602be15cb8b in /home/iacolippo/anaconda3/bin/python3.6)
frame #27: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #28: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #29: <unknown function> + 0x198b8b (0x5602be15cb8b in /home/iacolippo/anaconda3/bin/python3.6)
frame #30: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #31: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #32: <unknown function> + 0x197f24 (0x5602be15bf24 in /home/iacolippo/anaconda3/bin/python3.6)
frame #33: <unknown function> + 0x198dc1 (0x5602be15cdc1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #34: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #35: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #36: <unknown function> + 0x198b8b (0x5602be15cb8b in /home/iacolippo/anaconda3/bin/python3.6)
frame #37: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #38: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #39: <unknown function> + 0x197f24 (0x5602be15bf24 in /home/iacolippo/anaconda3/bin/python3.6)
frame #40: <unknown function> + 0x198dc1 (0x5602be15cdc1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #41: <unknown function> + 0x19ef95 (0x5602be162f95 in /home/iacolippo/anaconda3/bin/python3.6)
frame #42: _PyEval_EvalFrameDefault + 0x30a (0x5602be18462a in /home/iacolippo/anaconda3/bin/python3.6)
frame #43: PyEval_EvalCodeEx + 0x329 (0x5602be15d8d9 in /home/iacolippo/anaconda3/bin/python3.6)
frame #44: PyEval_EvalCode + 0x1c (0x5602be15e67c in /home/iacolippo/anaconda3/bin/python3.6)
frame #45: <unknown function> + 0x214ce4 (0x5602be1d8ce4 in /home/iacolippo/anaconda3/bin/python3.6)
frame #46: PyRun_FileExFlags + 0xa1 (0x5602be1d90e1 in /home/iacolippo/anaconda3/bin/python3.6)
frame #47: PyRun_SimpleFileExFlags + 0x1c4 (0x5602be1d92e4 in /home/iacolippo/anaconda3/bin/python3.6)
frame #48: Py_Main + 0x5ff (0x5602be1dcdaf in /home/iacolippo/anaconda3/bin/python3.6)
frame #49: main + 0xee (0x5602be0a38be in /home/iacolippo/anaconda3/bin/python3.6)
frame #50: __libc_start_main + 0xf1 (0x7f00bfae91c1 in /lib/x86_64-linux-gnu/libc.so.6)
frame #51: <unknown function> + 0x1c70da (0x5602be18b0da in /home/iacolippo/anaconda3/bin/python3.6)

So the culprit seems res.copy_(temp);, but I don’t understand how it should be modified. Any ideas?

same thing if I do: at::Tensor res = temp;

So I’ve heard that the happy people use torch::Tensor, not at::Tensor. (What also works, and this is needed when you do stuff in ATen itself, is to grab some input’s t.options() and use that in the factory functions.)
I’ll admit, that it’s a bit hard to find in the C++ docs at the moment, but they’ll likely soon be awesome.

Also, you don’t want to use indexing this way with tensors. As these seem to be pointwise ops, use TensorAccessors (auto t_a = temp.accessor<float, 3>()).
Note that you need to do type dispatching for this, if you want it to work with various dtypes. You’ll find patterns for this in the ATen/native subdirectory of the source code (e.g. I implemented the CPU CTC loss using accessors).

Best regards

Thomas

1 Like

Ahahahahah awesome answer, I’ll try to understand these things and probably come back with more questions :smiley:

Thank you!

For your original problem: The function you were looking for is probably slice. For non-advanced, there also is narrow.

Best regards

Thomas

Sorry for the probably stupid questions but I’m a total newbie here:

  • if I try to use torch::Tensor I get
fwht.cpp:7:8: error: ‘Tensor’ in namespace ‘torch’ does not name a type

Do I get to use this torch::Tensor only if I write my extension in the ATen native folder inside the pytorch repo and then recompile?

  • if I use at::Tensor everything compiles correctly but when I try to use my Python binding fwht.forward(torch.randn(10, 512)) I get
RuntimeError: copy_from does not support automatic differentiation; use copy_ instead (_s_copy_from at /home/iacolippo/pytorch/torch/csrc/autograd/VariableTypeManual.cpp:265)

From the README and the guide on how to write C++ extensions I understood that it’s not really a problem if autodiff breaks provided that you define a backward function, which I’m doing here.

#include <torch/extension.h>
#include<cmath>
#include <iostream>
#include <vector>


at::Tensor fwht_forward(
    at::Tensor input
    ) {
    auto n_samples = input.size(0);
    auto n_features = input.size(1);
    auto G = n_features / 2;
    auto M = 2;

    at::Tensor temp = at::zeros({n_samples, G, 2});
    for (auto i = 0; i < n_samples; i++) {
        for (auto j = 0; j < n_features; j = j + 2) {
            temp[i][j/2][0] = input[i][j] + input[i][j+1];
            temp[i][j/2][1] = input[i][j] - input[i][j+1];
        }
    }
    at::Tensor res = at::empty({n_samples, G, 2});
    res.copy_(temp);

    for (auto i = 2; i < std::log2(n_features) + 1; i++) {
        temp = at::zeros({n_samples, G / 2, M * 2});
        auto res_acc = res.accessor<float, 3>();
        auto t_a = temp.accessor<float, 3>();
        for (auto j = 0; j < res.size(0); j++) {
            for (auto k = 0; k < G; k = k + 2) {
                for (auto l = 0; l < res.size(2); l = l + 4) {
                    t_a[j][k/2][l] = res_acc[j][k][l/2] + res_acc[j][k+1][l/2];
                    t_a[j][k/2][l+1] = res_acc[j][k][l/2] - res_acc[j][k+1][l/2];
                    t_a[j][k/2][l+2] = res_acc[j][k][l/2+1] - res_acc[j][k+1][l/2+1];
                    t_a[j][k/2][l+3] = res_acc[j][k][l/2+1] + res_acc[j][k+1][l/2+1];
                }
            }
        }
        res.copy_(temp);
        G = G / 2;
        M = M * 2;
    }

    auto temp_acc = temp.accessor<float, 3>();
    at::Tensor output = at::zeros({n_samples, n_features});
    for (auto m = 0; m < n_samples; m++) {
        for (auto n = 0; n < temp.sizes()[2]; n++) {
            output[m][n] = temp_acc[m][0][n];
        }
    }
    return output * pow(std::sqrt(n_features), -1);
}


at::Tensor fwht_backward(
    at::Tensor input
    ) {
    auto n_samples = input.size(0);
    auto n_features = input.size(1);
    auto G = n_features / 2;
    auto M = 2;

    at::Tensor temp = at::zeros({n_samples, G, 2});
    for (auto i = 0; i < n_samples; i++) {
        for (auto j = 0; j < n_features; j = j + 2) {
            temp[i][j/2][0] = input[i][j] + input[i][j+1];
            temp[i][j/2][1] = input[i][j] - input[i][j+1];
        }
    }
    at::Tensor res = at::empty({n_samples, G, 2});
    res.copy_(temp);

    for (auto i = 2; i < std::log2(n_features) + 1; i++) {
        temp = at::zeros({n_samples, G / 2, M * 2});
        auto res_acc = res.accessor<float, 3>();
        auto t_a = temp.accessor<float, 3>();
        for (auto j = 0; j < res.size(0); j++) {
            for (auto k = 0; k < G; k = k + 2) {
                for (auto l = 0; l < res.size(2); l = l + 4) {
                    t_a[j][k/2][l] = res_acc[j][k][l/2] + res_acc[j][k+1][l/2];
                    t_a[j][k/2][l+1] = res_acc[j][k][l/2] - res_acc[j][k+1][l/2];
                    t_a[j][k/2][l+2] = res_acc[j][k][l/2+1] - res_acc[j][k+1][l/2+1];
                    t_a[j][k/2][l+3] = res_acc[j][k][l/2+1] + res_acc[j][k+1][l/2+1];
                }
            }
        }
        res.copy_(temp);
        G = G / 2;
        M = M * 2;
    }

    auto temp_acc = temp.accessor<float, 3>();
    at::Tensor output = at::zeros({n_samples, n_features});
    for (auto m = 0; m < n_samples; m++) {
        for (auto n = 0; n < temp.sizes()[2]; n++) {
            output[m][n] = temp_acc[m][0][n];
        }
    }
    return output * pow(std::sqrt(n_features), -1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &fwht_forward, "FWHT forward");
  m.def("backward", &fwht_backward, "FWHT backward");
}

what’s the issue with this copy_?