Sia_Rezaei
(Sia Rezaei)
December 10, 2021, 8:37pm
1
Hi,
I would like to work on this issue:
opened 03:43AM - 03 Nov 21 UTC · labels: triaged, enhancement, module: ux, module: viewing and reshaping
## 🚀 Feature
This is an operation that is almost the inverse of `where`.
Let's call it `sieve`.
It takes as input a tensor `A` and a boolean tensor `B` of the same shape and a "fill" value `c`.
It returns two tensors `O1` and `O2` of the same shape as `A`.
`O1` is equal to `A` where `B` is True and equal to `c` where `B` is False.
`O2` is equal to `A` where `B` is False and equal to `c` where `B` is True.
For example, say you want to separate the positive and non-positive elements of `A` into two tensors:
```
pos, neg = torch.sieve(A, B=(A > 0), fill=0)
```
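The proposed semantics could be sketched in plain Python over flat lists (the name `sieve_reference` is hypothetical, and this is a reference for the behavior, not an efficient implementation):

```python
def sieve_reference(a, b, fill=0):
    # Reference semantics for the proposed `sieve` over flat lists:
    # o1 keeps a[i] where b[i] is True (fill elsewhere); o2 is the complement.
    o1 = [x if m else fill for x, m in zip(a, b)]
    o2 = [fill if m else x for x, m in zip(a, b)]
    return o1, o2

# Separate positives from non-positives, as in the example above:
a = [3.0, -1.0, 0.0, 2.0]
pos, neg = sieve_reference(a, [x > 0 for x in a], fill=0)
# pos == [3.0, 0, 0, 2.0] and neg == [0, -1.0, 0.0, 0]
```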
## Motivation
This can be made twice as fast as current solutions.
## Pitch
This can obviously already be done in various ways in PyTorch.
But it can theoretically be twice as fast when done as a single operation.
That is because **ALL** current ways of achieving this have to go over `A` twice:
once to produce `O1` and once to produce `O2`.
But if an element does not go to `O1`, it has to go to `O2`,
so both outputs can be filled in a single pass over `A`.
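The single-pass idea can be illustrated with a plain-Python loop (a sketch of the routing logic only, not actual kernel code; the name `sieve_one_pass` is made up here):

```python
def sieve_one_pass(a, b, fill=0):
    # Both outputs start as `fill`; a single loop over `a` decides, per
    # element, which output receives it -- one read and one test per element.
    o1 = [fill] * len(a)
    o2 = [fill] * len(a)
    for i, (x, m) in enumerate(zip(a, b)):
        if m:
            o1[i] = x
        else:
            o2[i] = x
    return o1, o2
```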
## Alternatives
For the example mentioned above, one current solution is:
```
neg = torch.minimum(A, torch.zeros_like(A))
pos = torch.maximum(A, torch.zeros_like(A))
```
Here, each element of `A` is compared with zero twice, even though one comparison per element suffices.
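To make the duplicated work explicit, here is a plain-Python analogue of the `minimum`/`maximum` approach that counts comparisons (a sketch for illustration, not benchmark code):

```python
def minmax_two_pass(a):
    # Two full passes over `a`, each comparing every element with zero
    # once, so 2 * len(a) comparisons in total.
    comparisons = 0
    neg = []
    for x in a:
        comparisons += 1
        neg.append(min(x, 0))
    pos = []
    for x in a:
        comparisons += 1
        pos.append(max(x, 0))
    return pos, neg, comparisons

pos, neg, n = minmax_two_pass([3, -1, 2])
# n == 6, i.e. twice the element count
```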
## Context
I use this in a custom activation function, so having this operation would make my activation functions up to 2x faster, and that makes my runs noticeably faster.
I'm willing to dedicate some time to this if someone is willing to guide me, as I haven't written custom CPU or GPU kernel code.
It is a simple function, but I think it needs to be implemented at a very low level, since none of the primitives in the current API can address the issue.
Looking at the implementations of the `where` and `relu` functions should be helpful, but I can't quite locate their low-level implementations in the repo. It would be great if someone could point them out to me.
Any other pointers are also much appreciated.
Maybe start by creating it as a PyTorch extension?
https://pytorch.org/tutorials/advanced/cpp_extension.html
It doesn't need to be written inside core PyTorch in the first go, I guess.
Sia_Rezaei
(Sia Rezaei)
December 11, 2021, 1:42am
3
Can you point me to the implementations of the `where` and `relu` functions?