I'm fairly sure this is specific to my machine, probably its CPU. Everything else seems fine, but one day torch.where
became really slow.
import time
import torch
print(torch.__version__)
t = torch.randint(0, 100, (1000, 1000))
time_ = time.time()
torch.where(t == 0)
print('torch.where time:', time.time() - time_)
On the server, the result is:
1.6.0
torch.where time: 0.041625261306762695
But on my own MacBook Pro, it is much faster:
>>> import torch
... print(torch.__version__)
... import time
... t = torch.randint(0,100,(1000,1000))
... time_ = time.time();torch.where(t == 0);print('torch.where time:', time.time() - time_)
...
1.6.0
torch.where time: 0.0031599998474121094
I should also mention that the time on the server varies from run to run, but it is always longer than 0.01 s:
>>> time_ = time.time();torch.where(t == 0);print('torch.where time:', time.time() - time_)
(tensor([ 0, 0, 0, ..., 999, 999, 999]), tensor([ 65, 254, 354, ..., 937, 971, 984]))
torch.where time: 0.014759302139282227
>>> time_ = time.time();torch.where(t == 0);print('torch.where time:', time.time() - time_)
(tensor([ 0, 0, 0, ..., 999, 999, 999]), tensor([ 65, 254, 354, ..., 937, 971, 984]))
torch.where time: 0.06390643119812012
>>> time_ = time.time();torch.where(t == 0);print('torch.where time:', time.time() - time_)
(tensor([ 0, 0, 0, ..., 999, 999, 999]), tensor([ 65, 254, 354, ..., 937, 971, 984]))
torch.where time: 0.019058942794799805
>>> time_ = time.time();torch.where(t == 0);print('torch.where time:', time.time() - time_)
(tensor([ 0, 0, 0, ..., 999, 999, 999]), tensor([ 65, 254, 354, ..., 937, 971, 984]))
torch.where time: 0.041625261306762695
>>> time_ = time.time();torch.where(t == 0);print('torch.where time:', time.time() - time_)
(tensor([ 0, 0, 0, ..., 999, 999, 999]), tensor([ 65, 254, 354, ..., 937, 971, 984]))
torch.where time: 0.03616046905517578
>>> time_ = time.time();torch.where(t == 0);print('torch.where time:', time.time() - time_)
(tensor([ 0, 0, 0, ..., 999, 999, 999]), tensor([ 65, 254, 354, ..., 937, 971, 984]))
torch.where time: 0.0942380428314209
>>> time_ = time.time();torch.where(t == 0);print('torch.where time:', time.time() - time_)
(tensor([ 0, 0, 0, ..., 999, 999, 999]), tensor([ 65, 254, 354, ..., 937, 971, 984]))
torch.where time: 0.09639406204223633
>>> time_ = time.time();torch.where(t == 0);print('torch.where time:', time.time() - time_)
(tensor([ 0, 0, 0, ..., 999, 999, 999]), tensor([ 65, 254, 354, ..., 937, 971, 984]))
torch.where time: 0.09022760391235352
>>> time_ = time.time();torch.where(t == 0);print('torch.where time:', time.time() - time_)
(tensor([ 0, 0, 0, ..., 999, 999, 999]), tensor([ 65, 254, 354, ..., 937, 971, 984]))
torch.where time: 0.019794464111328125
>>> time_ = time.time();torch.where(t == 0);print('torch.where time:', time.time() - time_)
(tensor([ 0, 0, 0, ..., 999, 999, 999]), tensor([ 65, 254, 354, ..., 937, 971, 984]))
torch.where time: 0.08239603042602539
>>> time_ = time.time();torch.where(t == 0);print('torch.where time:', time.time() - time_)
(tensor([ 0, 0, 0, ..., 999, 999, 999]), tensor([ 65, 254, 354, ..., 937, 971, 984]))
torch.where time: 0.015592098236083984
>>> time_ = time.time();torch.where(t == 0);print('torch.where time:', time.time() - time_)
(tensor([ 0, 0, 0, ..., 999, 999, 999]), tensor([ 65, 254, 354, ..., 937, 971, 984]))
torch.where time: 0.024471759796142578
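To make the comparison between the two machines less noisy, here is a minimal sketch of a steadier measurement. It discards a few warm-up calls, repeats the timing, and tries several intra-op thread counts. The helper names time_call and bench_where and the thread counts (1, 2, 4) are made up for illustration. Thread oversubscription on a shared server is only one possible cause, not something confirmed here.

```python
import statistics
import time


def time_call(fn, warmup=3, repeats=20):
    """Return (min, median, max) wall-clock seconds over `repeats` calls of fn()."""
    for _ in range(warmup):  # warm-up calls are run but not timed
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return min(samples), statistics.median(samples), max(samples)


def bench_where(thread_counts=(1, 2, 4)):
    """Time torch.where(t == 0) at several thread counts (hypothetical helper)."""
    import torch  # imported here so time_call stays usable without torch

    t = torch.randint(0, 100, (1000, 1000))
    original = torch.get_num_threads()
    print("torch", torch.__version__, "default threads:", original)
    try:
        for n in thread_counts:
            torch.set_num_threads(n)
            lo, med, hi = time_call(lambda: torch.where(t == 0))
            print(f"threads={n}: min={lo:.4f}s median={med:.4f}s max={hi:.4f}s")
    finally:
        # Restore the thread count that was in effect before the benchmark.
        torch.set_num_threads(original)
```

Calling bench_where() in the same Python session on both the server and the MacBook would show whether the gap persists with a single thread and whether the spread between min and max is much wider on the server.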
Any idea what I should check on the server I'm using? Many thanks.