We’ve been working on a tool, `tsanley`, to find subtle shape errors in your deep learning code quickly and cheaply. The key idea is to label tensor variables with their expected shapes (e.g., `x: 'b,t,d' = ...`) using optional type annotations in Python 3.x, and to let `tsanley` perform shape validity checks automatically at runtime. It works with both small and large tensor programs.
Repository: https://github.com/ofnote/tsanley
Examples: models (ResNet, GraphNNs, Transformers)
Quick example:

```python
def foo(x):
    x: 'b,t,d'                    # expected shape of x is (B, T, D)
    y: 'b,d' = x.mean(dim=0) * 2  # error!
    z: 'b,d' = x.mean(dim=1)      # ok
    return y, z
```
The function `foo` contains tensor variables labeled with their named shapes using a shorthand notation. It has a subtle shape error in the assignment to `y`: we expect the shape of `y` to be (B, D), but `mean(dim=0)` removes the first dimension, not the second, so `y` actually has shape (T, D). PyTorch won’t flag this as an error; instead, we get a confusing shape-inconsistency error somewhere downstream.
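To see the mismatch concretely, here is the same pair of reductions written with NumPy as a stand-in for the PyTorch tensors above (`dim` becomes `axis`; the sizes b=10, t=100, d=1024 match those in the error report below):

```python
import numpy as np

# Stand-in for a (B, T, D) tensor: batch=10, time=100, dim=1024.
x = np.zeros((10, 100, 1024))

# Reducing over axis 0 removes the batch dimension, leaving (T, D) ...
y = x.mean(axis=0) * 2
assert y.shape == (100, 1024)   # not the (B, D) = (10, 1024) we annotated

# ... whereas reducing over axis 1 removes the time dimension, giving (B, D).
z = x.mean(axis=1)
assert z.shape == (10, 1024)
```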
`tsanley` finds such unexpected bugs quickly at runtime:

```
Update at line 37: actual shape of y = t,d
>> FAILED shape check at line 37
   expected: (b:10, d:1024), actual: (100, 1024)
Update at line 38: actual shape of z = b,d
>> shape check succeeded at line 38
```
Writing these named shape annotations manually can get tedious. `tsanley` can auto-annotate the tensor variables in your (or someone else’s) code, provided the code is executable. This is especially useful when digging into, or adapting, an existing codebase or library for your project.
The tool builds upon the `tsalib` library, which introduced a shorthand notation for labeling tensor variables with their named shapes, irrespective of the backend tensor library used.
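To give a flavor of what checking against the shorthand notation involves, here is a minimal, hand-rolled sketch that resolves a spec string like `'b,d'` against a table of dimension sizes and compares it to an actual shape. This is an illustration only, not `tsalib`’s or `tsanley`’s actual API:

```python
import numpy as np

def check_shape(shape, spec, dims):
    """Return True if `shape` matches a named-shape spec like 'b,d',
    resolving each dimension name via the `dims` table."""
    expected = tuple(dims[name.strip()] for name in spec.split(','))
    return tuple(shape) == expected

dims = {'b': 10, 't': 100, 'd': 1024}
x = np.zeros((dims['b'], dims['t'], dims['d']))

assert check_shape(x.mean(axis=1).shape, 'b,d', dims)      # ok: (10, 1024)
assert not check_shape(x.mean(axis=0).shape, 'b,d', dims)  # the bug in foo: (100, 1024)
```

The real tool does this automatically by tracing execution and checking each annotated variable as it is assigned, rather than requiring explicit calls.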
We would love feedback on tsanley and hope it is useful for your coding/debugging workflow.