Tensor.apply_ function

Is there another way to apply a callable to each element of a tensor besides apply_? apply_ is quite slow and only works with CPU tensors…


Tensor.apply_ is slow, and we don't have a great, efficient way to apply an arbitrary Python function to a tensor element by element. A common workaround for simple operations, though, is to use a mask.

For example, if you wanted to do something like tensor.apply_(lambda x: x + 2 if x > 5 else x), you could instead write result = (tensor > 5) * 2 + tensor.
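A minimal sketch comparing the two approaches, assuming a small 1-D float tensor on the CPU (the sample values are illustrative only). apply_ runs the Python lambda once per element and modifies the tensor in place, while the mask version is a handful of vectorized tensor operations:

```python
import torch

# Illustrative sample data; apply_ only works with CPU tensors.
tensor = torch.arange(10, dtype=torch.float32)

# Slow path: apply_ calls the lambda once per element, in place,
# so work on a clone to keep the original tensor unchanged.
slow = tensor.clone()
slow.apply_(lambda x: x + 2 if x > 5 else x)

# Vectorized path: (tensor > 5) is a bool mask; multiplying by 2 gives
# 0 or 2 per element, and adding the float tensor promotes the result to float.
fast = (tensor > 5) * 2 + tensor

print(torch.equal(slow, fast))  # True
```

The mask trick works because "add 2 only where the condition holds" is the same as "add 2 times an indicator of the condition"; it does not generalize to arbitrary Python logic.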

Another function that can sometimes help is torch.where.
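A minimal sketch of the same if/else written with torch.where(condition, input, other), which takes values from input where the condition is True and from other elsewhere. The device selection is an assumption added to show that, unlike apply_, this also runs on a GPU when one is available:

```python
import torch

# Use a GPU if one is present; torch.where is not limited to CPU tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"
tensor = torch.arange(10, dtype=torch.float32, device=device)

# Equivalent of: lambda x: x + 2 if x > 5 else x
result = torch.where(tensor > 5, tensor + 2, tensor)
```

Note that both branches (tensor + 2 and tensor) are computed for every element before torch.where selects between them, which is usually still far faster than a per-element Python callback.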


OK, I will give it a try. Thank you!