Hi guys,
I want to convert a tensor of size (4, 3, 2) into a list of four (3, 2) tensors and run it on mobile.
Here is the code that creates the TorchScript module:
```python
import torch

class BatchAsList(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self):
        # Build a (4, 3, 2) tensor holding 0..23
        x = torch.arange(4 * 3 * 2).view(4, 3, 2)
        a = []
        # Split along dim 0 into four (3, 2) tensors
        for i in range(x.size(0)):
            a.append(x[i])
        return a

module = BatchAsList()
script = torch.jit.script(module)
script.save('batch_as_list.pth')
```
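A quick check in eager Python may help narrow this down. It is a sketch that assumes the problem is related to how the slices are stored, which the post itself does not establish: each `x[i]` is a view that shares the storage of `x` and starts at a nonzero storage offset, while `x[i].clone()` gets its own storage starting at offset 0. If the Android side read each tensor's data from the start of the shared storage and ignored the offset, every list element would show 0..5, which is what the log below shows.

```python
import torch

# Same (4, 3, 2) tensor that BatchAsList.forward builds.
x = torch.arange(4 * 3 * 2).view(4, 3, 2)

# x[i] is a view: it reuses x's storage and starts at offset 6 * i.
view_offsets = [x[i].storage_offset() for i in range(x.size(0))]

# x[i].clone() copies the slice into fresh storage starting at offset 0.
clone_offsets = [x[i].clone().storage_offset() for i in range(x.size(0))]

print(view_offsets)   # [0, 6, 12, 18]
print(clone_offsets)  # [0, 0, 0, 0]
```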
And the code that runs it on Android:
```java
// mTestModule is loaded from file batch_as_list.pth
Tensor[] output = mTestModule.forward().toTensorList();
for (int i = 0; i < output.length; i++) {
    Log.i(TAG, "Output " + i + " shape : " + output[i]);
}
long[] output0 = output[0].getDataAsLongArray();
long[] output1 = output[1].getDataAsLongArray();
long[] output2 = output[2].getDataAsLongArray();
long[] output3 = output[3].getDataAsLongArray();
for (int j = 0; j < output0.length; j++) {
    Log.i(TAG, "Output : " + output0[j] + " ," + output1[j] + " ," + output2[j] + ", " + output3[j]);
}
```
And the result:
```
2019-12-12 10:35:10.352 18859-18859/org.pytorch.demo D/TEST: /data/user/0/org.pytorch.demo/files/batch_as_list.pth
2019-12-12 10:35:10.381 18859-18859/org.pytorch.demo I/TEST: Output 0 shape : Tensor([3, 2], dtype=torch.int64)
2019-12-12 10:35:10.381 18859-18859/org.pytorch.demo I/TEST: Output 1 shape : Tensor([3, 2], dtype=torch.int64)
2019-12-12 10:35:10.382 18859-18859/org.pytorch.demo I/TEST: Output 2 shape : Tensor([3, 2], dtype=torch.int64)
2019-12-12 10:35:10.382 18859-18859/org.pytorch.demo I/TEST: Output 3 shape : Tensor([3, 2], dtype=torch.int64)
2019-12-12 10:35:10.382 18859-18859/org.pytorch.demo I/TEST: Output : 0 ,0 ,0, 0
2019-12-12 10:35:10.382 18859-18859/org.pytorch.demo I/TEST: Output : 1 ,1 ,1, 1
2019-12-12 10:35:10.382 18859-18859/org.pytorch.demo I/TEST: Output : 2 ,2 ,2, 2
2019-12-12 10:35:10.382 18859-18859/org.pytorch.demo I/TEST: Output : 3 ,3 ,3, 3
2019-12-12 10:35:10.382 18859-18859/org.pytorch.demo I/TEST: Output : 4 ,4 ,4, 4
2019-12-12 10:35:10.382 18859-18859/org.pytorch.demo I/TEST: Output : 5 ,5 ,5, 5
```
The shapes of the outputs are correct, but the values are wrong: all four tensors contain 0 to 5. The output I expected is:
```
0, 6, 12, 18
1, 7, 13, 19
2, 8, 14, 20
3, 9, 15, 21
4, 10, 16, 22
5, 11, 17, 23
```
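One possible workaround, untested on a device and assuming the storage-offset explanation is the actual cause, is to `clone()` each slice inside `forward`, so that every returned tensor owns its own storage starting at offset 0. The sketch below scripts such a module and prints the values in the same layout as the expected output; `script.save('batch_as_list.pth')` can then be called as in the original code.

```python
from typing import List

import torch

class BatchAsList(torch.nn.Module):
    def forward(self) -> List[torch.Tensor]:
        x = torch.arange(4 * 3 * 2).view(4, 3, 2)
        a: List[torch.Tensor] = []
        for i in range(x.size(0)):
            # clone() gives each slice its own storage at offset 0, so a
            # consumer that ignores storage offsets would still read the right data.
            a.append(x[i].clone())
        return a

script = torch.jit.script(BatchAsList())
outputs = script()

# Row j holds element j of every flattened slice, matching the expected table.
for j in range(outputs[0].numel()):
    print(", ".join(str(int(t.flatten()[j])) for t in outputs))
```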