Some operations, such as nn.Bilinear, are extremely costly.
Can we have a mask for the input tensors so that only some of their entries are calculated and the rest are skipped?
For example,
input_a = input_b = torch.ones(A, B, ..., Z)
net = nn.Bilinear(Z, Z, Z)
mask = torch.rand(A, B, ..., Y, 1) > 0.9 # only a fraction needs to be calculated.
output = net(input_a, input_b, mask)
# entries where mask is False may come out as 0 or arbitrary values; that's fine.
output.sum().backward()
# input_a.grad should be nonzero only where mask is True, matching its pattern.
At the moment, the only approach I can think of is a wrapper function that uses gather and scatter.
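For what it's worth, here is a rough sketch of that gather/scatter wrapper idea. The function name `masked_bilinear` and the concrete shapes are my own assumptions, not an existing PyTorch API: it flattens the batch dimensions, runs nn.Bilinear only on the rows selected by the mask, and scatters the results back into a zero tensor via indexing (which autograd handles).

```python
import torch
import torch.nn as nn

def masked_bilinear(net, input_a, input_b, mask):
    """Apply `net` (an nn.Bilinear) only where `mask` is True.

    input_a, input_b: (..., in_features); mask: (..., 1) boolean.
    Unselected output rows stay zero. Hypothetical helper, not a torch API.
    """
    flat_a = input_a.reshape(-1, input_a.shape[-1])
    flat_b = input_b.reshape(-1, input_b.shape[-1])
    idx = mask.reshape(-1).nonzero(as_tuple=True)[0]  # rows to compute
    out = flat_a.new_zeros(flat_a.shape[0], net.out_features)
    out[idx] = net(flat_a[idx], flat_b[idx])  # compute only selected rows
    return out.reshape(*input_a.shape[:-1], net.out_features)

net = nn.Bilinear(8, 8, 8)
a = torch.ones(4, 5, 8, requires_grad=True)
b = torch.ones(4, 5, 8)
mask = torch.rand(4, 5, 1) > 0.9  # only a fraction is computed
out = masked_bilinear(net, a, b, mask)
out.sum().backward()
# a.grad is exactly zero at every position where mask is False
```

This avoids the full-size bilinear product, but note the gather/scatter and the extra zero tensor have their own overhead, so it only pays off when the mask is sparse enough.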
Thanks in advance!