Some operations, such as nn.Bilinear, are extremely costly.

Can we have a mask for the input tensors so that only some of their entries get calculated and the rest are skipped?

For example,

```
input_a = input_b = torch.ones(A, B, ..., Z)
net = nn.Bilinear(Z, Z, Z)
mask = torch.rand(A, B, ..., Y, 1) > 0.9  # only a fraction needs to be calculated
output = net(input_a, input_b, mask)
# masked-out entries may yield 0 or any other number, that's okay
output.sum().backward()
# input_a.grad should be nonzero only where mask is True
```

Right now I can only think of a wrapper function that uses gather and scatter.
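A rough sketch of that workaround, assuming a boolean mask over the leading (batch) dimensions: gather the selected rows, run `nn.Bilinear` on just that subset, and write the results back into a zero tensor. `masked_bilinear` and the concrete shapes below are my own illustrative names, not an existing API.

```python
import torch
import torch.nn as nn

def masked_bilinear(net, input_a, input_b, mask):
    # Flatten all leading dims so rows can be indexed with a 1-D index.
    flat_a = input_a.reshape(-1, input_a.shape[-1])
    flat_b = input_b.reshape(-1, input_b.shape[-1])
    idx = mask.reshape(-1).nonzero(as_tuple=True)[0]
    out = flat_a.new_zeros(flat_a.shape[0], net.out_features)
    # Compute the bilinear form only for the selected rows.
    out[idx] = net(flat_a[idx], flat_b[idx])
    return out.reshape(*input_a.shape[:-1], net.out_features)

A, B, Z = 4, 5, 8  # example sizes
net = nn.Bilinear(Z, Z, Z)
input_a = torch.ones(A, B, Z, requires_grad=True)
input_b = torch.ones(A, B, Z, requires_grad=True)
mask = torch.rand(A, B, 1) > 0.9
output = masked_bilinear(net, input_a, input_b, mask)
output.sum().backward()
# input_a.grad is zero wherever mask is False
```

This only saves compute when the mask is sparse enough to offset the indexing overhead, and backward flows through the index assignment, so the gradient pattern matches the mask.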

Thanks in advance!