Hello everyone,
I’m porting a Python module to C++ and I’m stuck on how to efficiently port:
np.where(np_array > some_float)
For a 2-D array, NumPy returns a tuple of two index arrays. I found that libtorch has a where function as well, but it returns an at::autograd::variable_list, which I guess is simply a std::vector<torch::Variable>.
I couldn’t find any documentation on how to use it, or on whether it is equivalent to np.where.
A simple test shows that torch::where returns two tensors, each of which seems to hold a set of indexes.
However, it’s not clear to me how to interpret these numbers: they contain duplicates and I can’t make sense of them.
Here is the simple demo:
#include <cstdlib>   // std::system
#include <iostream>
#include <torch/torch.h>
#include <torch/script.h>

int main()
{
    torch::Tensor preds = torch::randn({ 4, 8 });
    torch::Tensor offsets = torch::randn({ 1, 4, 4, 8 });

    // Single-argument form: only a condition is passed, as with np.where(cond)
    auto res = torch::where(preds > 0.9);

    std::cout << "res0.size(): " << res[0].sizes()[0] << std::endl;
    if (res[0].sizes()[0] == 0)
    {
        std::cout << "nothing is returned!" << std::endl;
    }

    std::cout << "preds:\n" << preds << std::endl;
    std::cout << "offsets:\n" << offsets << std::endl;
    std::cout << "res.size: " << res.size() << std::endl;
    std::cout << "res0:\n" << res[0] << std::endl;
    std::cout << "res1:\n" << res[1] << std::endl;
    std::cout << "Hello wordl!" << std::endl;
    std::system("pause");  // Windows-only: keeps the console window open
}
This produces the following output:
res0.size(): 9
preds:
-0.9659 -0.9782 -1.7985 -1.5461 -0.2913 0.2766 -1.6117 -0.4642
-1.8279 0.9883 -1.0430 -0.0729 1.4847 -0.8073 -1.8110 0.1578
0.6299 0.6409 1.0483 0.9996 -0.5102 -0.5886 0.9410 0.3633
-0.4674 1.4806 1.0258 -0.5636 0.9242 -0.1481 -1.8079 2.4731
[ CPUFloatType{4,8} ]
offsets:
(1,1,.,.) =
-0.1614 0.3842 -0.8396 -0.0390 -0.5792 0.1151 -1.4113 0.4774
0.1490 1.7119 -0.3051 -0.1585 -1.2240 -0.7920 -0.7085 -0.1590
-0.6625 1.6923 -0.0494 -0.5374 -0.7239 -0.6688 -1.7225 0.3280
0.8866 2.4275 -0.4948 1.6180 -0.4133 -0.5967 0.3257 -0.7997
(1,2,.,.) =
-0.5249 1.8908 -0.5924 1.3851 0.8556 0.4914 0.6360 0.8373
0.2659 -1.3615 0.1923 -0.9496 -1.8631 -0.1948 -0.3652 0.0215
-0.2501 2.2011 0.8395 -0.4594 -0.2585 0.1908 0.4444 0.2588
0.5341 -1.6821 0.3597 -0.0394 -0.0743 0.5159 -0.9592 1.0634
(1,3,.,.) =
-1.3950 -0.7348 1.1979 1.0431 0.4164 -0.1143 0.0959 0.1557
2.0382 0.0820 0.2825 0.5211 2.0970 0.2202 -1.4359 -1.0877
0.1951 0.4328 -0.2317 -0.7700 -0.4143 -1.8647 -0.6055 0.8088
0.7052 -0.1012 0.3797 -1.1983 -1.0531 0.6381 -0.6158 -0.6769
(1,4,.,.) =
-1.0985 -1.0719 -0.3173 -0.2694 1.4531 -0.7597 0.4511 0.5871
-1.2546 1.6486 -1.1799 0.7037 -0.6020 -0.8686 1.3177 -1.1100
0.4443 -0.7602 -1.3926 -0.0029 -0.9942 1.0145 -1.4573 -0.0027
0.5276 0.5788 0.7841 -0.1466 0.3800 -0.0146 -0.1285 -2.1790
[ CPUFloatType{1,4,4,8} ]
res.size: 2
res0:
1
1
2
2
2
3
3
3
3
[ CPULongType{9} ]
res1:
1
4
2
3
6
1
2
4
7
[ CPULongType{9} ]
Hello wordl!
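One reading that fits the output above is that the two tensors are row and column indexes paired element by element, so (res0[i], res1[i]) would be the position of the i-th element of preds that is greater than 0.9. This has not been confirmed against the libtorch documentation. The sketch below checks that reading in plain C++ without any libtorch calls: `IndexPair`, `where_greater`, and `example_preds` are hypothetical names made up for this illustration, and the loop is only an assumed model of what np.where does on a 2-D array, not the library's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper type, not part of libtorch or NumPy.
struct IndexPair {
    std::vector<std::int64_t> rows;  // plays the role of res[0]
    std::vector<std::int64_t> cols;  // plays the role of res[1]
};

// Assumed model of np.where(a > threshold) for a 2-D array that is
// stored row-major in a flat vector: scan row by row and record the
// (row, column) position of every element above the threshold.
IndexPair where_greater(const std::vector<float>& values,
                        std::size_t n_rows, std::size_t n_cols,
                        float threshold) {
    assert(values.size() == n_rows * n_cols);
    IndexPair out;
    for (std::size_t r = 0; r < n_rows; ++r) {
        for (std::size_t c = 0; c < n_cols; ++c) {
            if (values[r * n_cols + c] > threshold) {
                out.rows.push_back(static_cast<std::int64_t>(r));
                out.cols.push_back(static_cast<std::int64_t>(c));
            }
        }
    }
    return out;
}

// The 4x8 preds values from the output above, copied row by row.
std::vector<float> example_preds() {
    return {
        -0.9659f, -0.9782f, -1.7985f, -1.5461f, -0.2913f,  0.2766f, -1.6117f, -0.4642f,
        -1.8279f,  0.9883f, -1.0430f, -0.0729f,  1.4847f, -0.8073f, -1.8110f,  0.1578f,
         0.6299f,  0.6409f,  1.0483f,  0.9996f, -0.5102f, -0.5886f,  0.9410f,  0.3633f,
        -0.4674f,  1.4806f,  1.0258f, -0.5636f,  0.9242f, -0.1481f, -1.8079f,  2.4731f,
    };
}
```

Calling `where_greater(example_preds(), 4, 8, 0.9f)` gives rows 1, 1, 2, 2, 2, 3, 3, 3, 3 and columns 1, 4, 2, 3, 6, 1, 2, 4, 7, the same values as res0 and res1. Under this reading the duplicates in res0 just mean that one row has several matching columns; for example, row 1 matches at columns 1 and 4 (0.9883 and 1.4847).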
Any help is greatly appreciated.