I have a conditional operation on a tensor, such as
for i:
    if check(rate[i]):
        rate[i] = reset(rate[i])
Right now, I implement it as follows:
rate = torch.where(check(rate), reset(rate), rate)
It turns out that reset(rate) is still evaluated on every element, even those where check(rate) is false, as if I had written
new_rate = reset(rate)
rate = torch.where(check(rate), new_rate, rate)
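To make this concrete, here is a minimal sketch that counts how many elements reset actually processes. The check condition (t < 0.5), the seen list, and the four-element tensor are stand-ins I made up for illustration; reset is the same function as in my benchmark further down.

```python
import torch

seen = []  # number of elements each reset() call receives

def check(t):
    # hypothetical condition, chosen only for this illustration
    return t < 0.5

def reset(t):
    seen.append(t.numel())
    return 1.0 / torch.log(1.0 + t)

rate = torch.tensor([0.1, 0.9, 0.3, 0.7])

# Python evaluates the arguments before calling torch.where,
# so reset() runs on the full tensor, not just where check() is True.
rate = torch.where(check(rate), reset(rate), rate)

print(seen)  # all four elements went through reset()
```

Only two of the four elements satisfy the condition, yet reset receives all four.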
I am wondering whether there is any way to improve the performance here when reset(rate) is really expensive.
It seems a mask should help in this case, like
mask = check(rate)
rate[mask] = reset(rate[mask])
Now reset is applied only to the elements selected by the mask. However, this is even slower.
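As a small sketch of what the masked version does, using the same made-up check condition (t < 0.5), seen counter, and four-element tensor as illustration only:

```python
import torch

seen = []  # number of elements each reset() call receives

def check(t):
    # hypothetical condition, chosen only for this illustration
    return t < 0.5

def reset(t):
    seen.append(t.numel())
    return 1.0 / torch.log(1.0 + t)

rate = torch.tensor([0.1, 0.9, 0.3, 0.7])
mask = check(rate)

# Boolean-mask indexing copies the selected elements into a new 1-D tensor.
selected = rate[mask]

# reset() sees only the selected elements; the assignment writes them back.
rate[mask] = reset(selected)

print(seen, selected.shape)
```

So reset does less work here, but the indexing has to copy the selected elements out and write the results back in. My guess is that this extra copying is what costs the time, though I have not profiled it.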
Here is an example:
import torch
import time
import copy
def reset(t):
    return 1.0 / torch.log(1.0 + t)
# two identical 10000 x 10000 tensors filled with uniform random values in [0, 1)
source1 = torch.FloatTensor(10000, 10000)
source1.uniform_()
source2 = copy.deepcopy(source1)
# test1: boolean-mask indexing, reset() runs only on the selected elements
start = time.time()
mask = source1 < 0.5
source1[mask] = reset(source1[mask])
print('test1 takes: ', time.time() - start)
# test2: torch.where, reset() runs on the whole tensor
start = time.time()
mask = source2 < 0.5
source2 = torch.where(mask, reset(source2), source2)
print('test2 takes: ', time.time() - start)
On my machine, this prints:
test1 takes: 3.4766557216644287
test2 takes: 1.7760250568389893