Can we add a shape check when rewriting the computation graph in TorchScript?

I want to rewrite the computation graph by transforming the original op into a custom op, but my custom op has tensor shape restrictions. Can we add a shape check when rewriting?

Can you share a minimal reproducible example so people can debug your problem?

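```python
import torch

# Pattern: pick entries of %x along dimension %3 at indices %1,
# then multiply the result elementwise by %2.
pattern = """
    graph(%x, %1, %2, %3):
        %0 : Tensor = aten::index_select(%x, %3, %1)
        %out : Tensor = aten::mul(%0, %2)
        return (%out)
"""

# Replacement: a single call to the custom fused op.
replacement = """
    graph(%x, %1 : Tensor, %2, %3):
        %out : Tensor = index_mul::index_mul(%x, %2, %1)
        return (%out)
"""

# `graph` is the torch._C.Graph being rewritten (e.g. scripted_module.graph).
torch._C._jit_pass_custom_pattern_based_rewrite_graph(pattern, replacement, graph)
```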

The operation index_mul::index_mul(%x, %2, %1) requires %x to be a 2-D tensor, %2 a 2-D tensor, and %1 a 1-D tensor. Can we add a shape check to the subgraph rewrite?
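The C++ `SubgraphRewriter` that backs this pass already supports this kind of check: `runOnGraph` accepts a match filter, a callback that receives the `Match` and the pattern's name-to-value map and returns false to skip that match. As far as I know, the Python binding `_jit_pass_custom_pattern_based_rewrite_graph` does not expose that argument, so the filter has to be registered from C++ (or the matches checked by hand in Python). Inside a filter, the graph value matched by `%x` is `match.values_map.at(vmap.at("x"))`, and its `TensorType` only carries a rank if shape information was propagated into the graph beforehand. The sketch below shows just the decision logic in plain Python, with ranks passed in as a dict; `REQUIRED_RANKS` and `index_mul_match_ok` are hypothetical names, and treating an unknown rank as "skip" is an assumption.

```python
from typing import Mapping, Optional

# Hypothetical table: rank each pattern input must have for
# index_mul::index_mul, keyed by the pattern value name (%x, %2, %1).
REQUIRED_RANKS = {"x": 2, "2": 2, "1": 1}


def index_mul_match_ok(ranks: Mapping[str, Optional[int]]) -> bool:
    """Decide whether one pattern match may be rewritten.

    `ranks` maps a pattern value name to the rank of the graph value it
    matched, or None when the rank is unknown (no shape info in the graph).
    Returns True only if every constrained input has exactly the required
    rank; missing or unknown ranks make the rewrite be skipped.
    """
    for name, required in REQUIRED_RANKS.items():
        actual = ranks.get(name)
        if actual is None or actual != required:
            return False
    return True
```

Rejecting unknown ranks is the conservative choice: if the graph carries no shape information, every match is left as `index_select` + `mul` instead of being turned into an op that may fail at run time. Note also that the replacement drops `%3`, the `dim` argument of `aten::index_select`, so the same filter is a natural place to check that `%3` is the constant dimension `index_mul::index_mul` assumes.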