How can I speed up a conditional tensor operation?

I have a conditional tensor operation, such as

for i in range(len(rate)):
   if check(rate[i]):
      rate[i] = reset(rate[i])

Right now, I implement it as follows:
rate = torch.where(check(rate), reset(rate), rate)

It turns out that reset(rate) is still evaluated on every element, even those where check(rate) is false, as if I had written

new_rate = reset(rate)
rate = torch.where(check(rate), new_rate, rate)
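This follows from how Python evaluates function arguments: reset(rate) is computed in full before torch.where is called at all. A minimal sketch that makes this visible is below. Here check is a hypothetical condition (elements below 0.5), and the calls list is only an illustrative counter, not part of any API.

```python
import torch

calls = []  # illustrative counter: number of elements each reset() call processed

def check(t):
    # Hypothetical condition for this sketch: elements below 0.5 need a reset.
    return t < 0.5

def reset(t):
    calls.append(t.numel())
    return 1.0 / torch.log(1.0 + t)

rate = torch.tensor([0.1, 0.9, 0.3, 0.7])

# Both arguments are fully evaluated before torch.where selects between them,
# so reset() sees all 4 elements even though only 2 satisfy check().
result = torch.where(check(rate), reset(rate), rate)
print(calls)  # [4]
```

Elements where check is false keep their original value in result, but the work for them has already been done.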

Is there any way to improve the performance here when reset(rate) is really expensive?

It seems a mask should help here, like

mask = check(rate)
rate[mask] = reset(rate[mask])

Now reset is applied only to the masked elements. However, this turns out to be even slower.
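To make the steps hidden in that one-line masked update explicit, here is a small sketch that splits it into three parts. The variable names selected and updated are only illustrative. Boolean-mask indexing returns a new 1-D tensor (a copy, not a view), so the masked version pays for a gather and a scatter on top of reset itself, which may be part of the extra cost.

```python
import torch

def reset(t):
    return 1.0 / torch.log(1.0 + t)

rate = torch.tensor([[0.1, 0.9], [0.3, 0.7]])
mask = rate < 0.5

selected = rate[mask]      # gather: copies the masked elements into a new 1-D tensor
updated = reset(selected)  # reset runs only on the selected elements
rate[mask] = updated       # scatter: writes the results back into their positions

print(selected.shape)  # torch.Size([2])
```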

Here is an example:

import torch
import time
import copy

def reset(t):
   return 1.0 / torch.log(1.0 + t)

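# Note: torch.FloatTensor(n, m) allocates uninitialized memory, so the values
# (and the fraction of elements below 0.5) are arbitrary.
source1 = torch.FloatTensor(10000, 10000)
source2 = copy.deepcopy(source1)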

# test1: masked assignment, reset only sees the selected elements
start = time.time()
mask = source1 < 0.5
source1[mask] = reset(source1[mask])
print('test1 takes: ', time.time() - start)

# test2: torch.where, reset runs on the full tensor
start = time.time()
mask = source2 < 0.5
source2 = torch.where(mask, reset(source2), source2)
print('test2 takes: ', time.time() - start)

On my machine:
test1 takes: 3.4766557216644287
test2 takes: 1.7760250568389893
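For a more repeatable comparison, a sketch along these lines might help. It assumes CPU tensors, uses torch.rand so the input is initialized, shrinks the tensor to 1000 x 1000, and takes the best of several runs. The helper names masked_assign, where_update, and best_of are made up for this sketch, and masked_assign includes the cost of a clone so the input is not modified between runs.

```python
import time
import torch

def reset(t):
    return 1.0 / torch.log(1.0 + t)

def masked_assign(src):
    # Masked update on a copy, so repeated runs see the same input.
    out = src.clone()
    mask = out < 0.5
    out[mask] = reset(out[mask])
    return out

def where_update(src):
    # reset() is evaluated on the whole tensor before torch.where selects.
    return torch.where(src < 0.5, reset(src), src)

def best_of(fn, src, repeats=5):
    # Minimum wall-clock time over several runs, to reduce noise.
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(src)
        times.append(time.perf_counter() - start)
    return min(times)

torch.manual_seed(0)
source = torch.rand(1000, 1000)  # initialized values in [0, 1)

# Both variants should produce the same result.
assert torch.allclose(masked_assign(source), where_update(source))

print('masked assignment:', best_of(masked_assign, source))
print('torch.where:      ', best_of(where_update, source))
```

On a GPU, the timing would also need torch.cuda.synchronize() before reading the clock, since CUDA kernels launch asynchronously.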