We’ve been working on tsanley, a tool for finding subtle shape errors in your deep learning code quickly and cheaply. The key idea is to label tensor variables with their expected shapes (e.g., x: 'b,t,d' = ...) using Python 3's optional type annotations, and let tsanley perform shape validity checks automatically at runtime. It works with both small and large tensor programs, e.g., models such as ResNets, graph neural networks, and Transformers. For example:
    def foo(x):
        x: 'b,t,d'                     # expected shape of x is (B, T, D)
        y: 'b,d' = x.mean(dim=0) * 2   # error!
        z: 'b,d' = x.mean(dim=1)       # ok
        return y, z
foo contains tensor variables labeled with their named shapes using a shorthand notation. It has a subtle shape error in the assignment to y: we expect the shape of y to be 'b,d', but mean(dim=0) got rid of the first dimension, not the second, so y actually has shape 't,d'.
PyTorch won’t flag this as an error: instead, we get a confusing shape-inconsistency error somewhere downstream.
tsanley finds such unexpected bugs quickly at runtime:
    Update at line 37: actual shape of y = t,d
    >> FAILED shape check at line 37
       expected: (b:10, d:1024), actual: (100, 1024)
    Update at line 38: actual shape of z = b,d
    >> shape check succeeded at line 38
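To make the mechanism concrete, here is a minimal, hypothetical sketch of the idea behind such runtime checks (this is not tsanley's actual implementation): bind each dimension name to a concrete size the first time it is seen, then validate every later annotated tensor against those bindings. The `check` helper, the `dim_sizes` dict, and the use of NumPy here are all illustrative assumptions.

```python
import numpy as np

dim_sizes = {}  # dimension name -> concrete size, filled in as checks run

def check(tensor, spec):
    """Check tensor.shape against a shorthand spec like 'b,t,d'."""
    names = spec.split(',')
    if len(names) != tensor.ndim:
        raise ValueError(f"rank mismatch: expected {spec}, got {tensor.shape}")
    for name, size in zip(names, tensor.shape):
        # setdefault binds the name on first sight; later sightings must agree
        if dim_sizes.setdefault(name, size) != size:
            raise ValueError(f"dim '{name}': expected {dim_sizes[name]}, got {size}")

x = np.zeros((10, 100, 1024))
check(x, 'b,t,d')             # binds b=10, t=100, d=1024
check(x.mean(axis=1), 'b,d')  # ok: shape (10, 1024)
try:
    check(x.mean(axis=0), 'b,d')  # fails: shape is (100, 1024), but b=10
except ValueError as e:
    print(e)  # prints: dim 'b': expected 10, got 100
```

A real tool has to intercept the annotated assignments as the program runs rather than require explicit `check` calls, but the dimension-binding logic is the core of the idea.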
Writing these named shape annotations manually can also get tedious. tsanley can auto-annotate the tensor variables in your (or someone else's) code, provided the code is executable. This is especially useful when digging deep into, or adapting, an existing codebase or library for your project.
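As a rough illustration of what runtime shape discovery can look like (a hypothetical sketch, not tsanley's auto-annotation machinery), a decorator can print the observed shapes of a function's tensor inputs and outputs, which you could then turn into named-shape annotations by hand:

```python
import functools
import numpy as np

def announce_shapes(fn):
    """Hypothetical helper: print observed shapes of tensor args and results."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        for i, a in enumerate(args):
            if hasattr(a, 'shape'):
                print(f"{fn.__name__}: arg {i} has shape {a.shape}")
        result = fn(*args, **kwargs)
        outs = result if isinstance(result, tuple) else (result,)
        for i, r in enumerate(outs):
            if hasattr(r, 'shape'):
                print(f"{fn.__name__}: output {i} has shape {r.shape}")
        return result
    return wrapper

@announce_shapes
def foo(x):
    return x.mean(axis=0) * 2, x.mean(axis=1)

y, z = foo(np.zeros((10, 100, 1024)))
# prints shapes (10, 100, 1024), (100, 1024), (10, 1024)
```

Mapping the concrete sizes back to dimension names (10 -> b, 100 -> t, 1024 -> d) is what turns raw observed shapes into the shorthand annotations.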
The tool builds upon the tsalib library, which introduced a shorthand notation for labeling tensor variables with their named shapes, irrespective of the backend tensor library used.
We would love feedback on tsanley and hope it is useful for your coding/debugging workflow.