Indexing using the C++ APIs


Hei Law
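```python
#!/usr/bin/env python

import torch

def main():
    # 4 x 64 x 64 tensor of zeros on the GPU
    data = torch.zeros(4, 64, 64, device="cuda")
    # one box per row: (x0, y0, x1, y1), also stored on the GPU
    locs = torch.tensor([[2, 2, 4, 4]], dtype=torch.long, device="cuda")

    for x0, y0, x1, y1 in locs:
        # x0, y0, x1, y1 are 0-dim CUDA tensors, not Python ints
        print('x0: {}'.format(type(x0)))
        print('data: {}'.format(data[:, x0:x1, y0:y1]))

if __name__ == "__main__":
    main()
```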

I want to write a C++ extension for the above code. Looking through the C++ API in aten/src/ATen/core/Tensor.h, I found this method:
Tensor slice(int64_t dim=0, int64_t start=0, int64_t end=9223372036854775807, int64_t step=1) const;
which lets me do slicing. The problem is that it only takes int64_t arguments, whereas x0, y0, x1 and y1 are CUDA tensors. I wonder how PyTorch implements the Python version of slicing (i.e. data[:, x0:x1, y0:y1]); I tried searching the source code but couldn't find it.

How should I implement the C++ extension for the above code?
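One way to see what the Python side has to do: `data[:, x0:x1, y0:y1]` reaches `__getitem__` as a tuple of `slice` objects whose bounds are still the original objects, and Python's slice machinery turns non-int bounds into ints by calling `__index__` on them. An integer one-element tensor implements `__index__` (roughly the equivalent of `.item()`), so, as far as I understand, the bounds become plain integers before any per-dimension `slice(dim, start, end, step)` call is made. The sketch below imitates this in plain Python with two hypothetical classes, `IndexLike` (standing in for a 0-dim integer tensor) and `SliceRecorder` (standing in for `data`); it uses `slice.indices()` for the conversion, which is not necessarily the exact routine PyTorch calls internally.

```python
class IndexLike:
    """Hypothetical stand-in for a 0-dim integer tensor: only __index__ is defined."""

    def __init__(self, value):
        self.value = value
        self.index_calls = 0

    def __index__(self):
        # A real one-element integer tensor would copy its value to the host here.
        self.index_calls += 1
        return self.value


class SliceRecorder:
    """Hypothetical object that turns data[...] into (dim, start, end, step) tuples."""

    def __init__(self, shape):
        self.shape = shape

    def __getitem__(self, key):
        # data[:, x0:x1, y0:y1] arrives here as a tuple of slice objects whose
        # bounds are still the original (non-int) objects. Only slices are handled.
        if not isinstance(key, tuple):
            key = (key,)
        calls = []
        for dim, s in enumerate(key):
            # slice.indices() converts each bound via __index__ and clamps it to
            # the size of this dimension, yielding plain Python ints.
            start, end, step = s.indices(self.shape[dim])
            calls.append((dim, start, end, step))
        return calls


x0, y0, x1, y1 = (IndexLike(v) for v in (2, 2, 4, 4))
data = SliceRecorder((4, 64, 64))
print(data[:, x0:x1, y0:y1])
# [(0, 0, 4, 1), (1, 2, 4, 1), (2, 2, 4, 1)]
```

Each tuple corresponds to one `slice(dim, start, end, step)` call from Tensor.h. In a C++ extension the conversion has to be written out explicitly, e.g. `data.slice(1, x0.item<int64_t>(), x1.item<int64_t>()).slice(2, y0.item<int64_t>(), y1.item<int64_t>())`. For CUDA tensors, `item()` copies the value back to the host and waits for the GPU, which the Python path presumably also does when it reads the bounds.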