Guiding the JITer to properly fuse non-contiguous map-reduce operations

Example case:

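```python
import torch
from torch import Tensor

@torch.jit.script
def foo(feats: Tensor, mask: Tensor) -> Tensor:
    # mask : [batch features] (bool)
    # feats : [features features]
    return torch.where(  # ((1))
        mask.unsqueeze(-1),
        # shape: [batch features 1], so mask[b, i] selects row i of feats
        feats.unsqueeze(0),
        # shape: [1 features features]
        torch.ones([1], dtype=feats.dtype, device=feats.device),
    # shape before prod: [batch features features]
    ).prod(1)  # ((2))
    # shape after prod: [batch features]
```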

((1)) and ((2)) should be fusible: instead of materializing the [batch features features] intermediate, a fused kernel could accumulate the product directly into the [batch features] output tensor while iterating over the mask and the data.
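A rough sketch of that accumulation pattern using only existing primitives: loop over the reduced dimension and multiply each masked row of `feats` into a [batch features] accumulator, so the [batch features features] tensor never exists. It assumes `mask` is a bool tensor of shape [batch features] that selects rows of `feats`, i.e. `out[b, j]` is the product of `feats[i, j]` over every `i` with `mask[b, i]` set. `masked_prod_loop` is a hypothetical name, and the Python loop runs `features` times, so this trades speed for memory rather than being a real fused kernel.

```python
import torch

def masked_prod_loop(feats: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. feats: [features, features], mask: [batch, features] (bool).
    # Computes out[b, j] = prod over i with mask[b, i] of feats[i, j]
    # while keeping peak extra memory at O(batch * features).
    batch, n = mask.shape
    out = torch.ones(batch, feats.shape[1], dtype=feats.dtype, device=feats.device)
    one = torch.ones((), dtype=feats.dtype, device=feats.device)
    for i in range(n):
        # Row i contributes feats[i, :] where mask[:, i] is set, else 1.
        out.mul_(torch.where(mask[:, i : i + 1], feats[i], one))
    return out
```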

This… would probably be hell for the vectorizer. But at this point I’d take the performance hit in exchange for being able to fit the thing in memory.

BUT… if any guru around here can show me the logical way I’ve been missing for doing this memory efficiently with existing primitives, please do! :smiley:
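One option that avoids the big intermediate entirely, under the same reading of the shapes (`mask[b, i]` selects row `i` of `feats`): a product is a sum in log space, so the masked product becomes a matrix multiply of the mask with `log|feats|`, while signs and zeros are tracked by two more matmuls that simply count the selected negatives and zeros. This is a swapped-in technique, not a fused `where`/`prod`; `masked_prod_matmul` is a hypothetical name, and the exp/log round trip means results match the direct product only up to floating-point tolerance. Peak extra memory is on the order of batch·features + features².

```python
import torch

def masked_prod_matmul(feats: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. feats: [features, features], mask: [batch, features] (bool).
    # out[b, j] = prod over i with mask[b, i] of feats[i, j], via log-space matmuls.
    m = mask.to(feats.dtype)                        # [batch, features], 0/1 weights
    is_zero = (feats == 0).to(feats.dtype)          # [features, features]
    is_neg = (feats < 0).to(feats.dtype)
    # Replace zeros by 1 so they add log(1) = 0; zeros are handled via n_zero below.
    safe_abs = torch.where(feats == 0, torch.ones_like(feats), feats.abs())
    log_mag = m @ safe_abs.log()                    # sum of selected log-magnitudes
    n_zero = m @ is_zero                            # count of selected zeros
    n_neg = m @ is_neg                              # count of selected negatives
    sign = 1.0 - 2.0 * torch.remainder(n_neg.round(), 2)  # -1 for an odd count
    out = sign * log_mag.exp()
    return torch.where(n_zero > 0, torch.zeros_like(out), out)
```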