Can we save computation for some entries in a tensor?

Some operations, such as nn.Bilinear, are extremely costly.

Can we have a mask for the input tensors so that some of their entries get calculated and some do not?

For example,

input_a = input_b = torch.ones(A, B, ..., Z)
net     = nn.Bilinear(Z, Z, Z)
mask    = torch.rand(A, B, ..., Y, 1) > 0.9 # only a fraction needs to be calculated.
output  = net(input_a, input_b, mask)
# masked-out entries may yield 0 or any other value; that's okay.

output.sum().backward()
# input_a.grad should be nonzero only where mask is True

So far I can only think of a wrapper function involving scatter and gather.
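For reference, here is a minimal sketch of that gather/scatter idea. The helper `masked_bilinear` is hypothetical (not a PyTorch API): it flattens the leading dimensions, runs nn.Bilinear only on the rows selected by the mask, and scatters the results back into a zero-filled output. Autograd then propagates gradients only to the computed rows.

```python
import torch
import torch.nn as nn

def masked_bilinear(net, input_a, input_b, mask):
    """Hypothetical wrapper: evaluate `net` only where `mask` is True.

    Masked-out entries of the output are left as zeros.
    """
    # Flatten all leading dims so each position becomes one row.
    flat_a = input_a.reshape(-1, input_a.shape[-1])
    flat_b = input_b.reshape(-1, input_b.shape[-1])
    idx = mask.reshape(-1).nonzero(as_tuple=True)[0]  # rows to compute
    out = flat_a.new_zeros(flat_a.shape[0], net.out_features)
    if idx.numel() > 0:
        # Gather the selected rows, run the expensive op, scatter back.
        out[idx] = net(flat_a[idx], flat_b[idx])
    return out.reshape(*input_a.shape[:-1], net.out_features)

A, B, Z = 4, 5, 8
net = nn.Bilinear(Z, Z, Z)
input_a = torch.ones(A, B, Z, requires_grad=True)
input_b = torch.ones(A, B, Z)
mask = torch.rand(A, B, 1) > 0.9  # only a fraction is computed

output = masked_bilinear(net, input_a, input_b, mask)
output.sum().backward()
# input_a.grad is zero at every position where mask is False
```

This avoids the flops for the skipped rows, but whether it is faster in practice depends on the mask density and the overhead of the indexing ops.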

Thanks in advance!