Canonical way to assert tensor shape

What is the canonical way to assert that a tensor has the correct shape when the expected shape is known beforehand? Currently, I use assertions like the following, which add a lot of clutter to the code:

assert x.shape == torch.Size([dim1, dim2])

A similar question about IDE-based tensor-shape checking has been asked here, but has not received an answer.

assert x.shape == (dim1, dim2) works as well, since torch.Size is a subclass of tuple.
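If the remaining clutter is a concern, one option is a small helper that keeps the call to one line but produces a clearer failure message. The following is a minimal sketch, not a PyTorch API: assert_shape is a hypothetical name, and treating None as a wildcard (for example, for a batch dimension that varies) is an assumption added here, not part of the answer above. Because torch.Size is a tuple subclass, the helper only needs a tuple-like .shape, so it would also accept a NumPy array.

```python
from typing import Optional, Sequence, Tuple


def assert_shape(x, expected: Sequence[Optional[int]]) -> None:
    """Hypothetical helper: check x.shape against `expected`.

    A None entry in `expected` matches any size in that position.
    Works for anything with a tuple-like .shape (torch.Size subclasses tuple).
    """
    actual: Tuple[int, ...] = tuple(x.shape)
    # The number of dimensions must match, then every non-None entry must be equal.
    ok = len(actual) == len(expected) and all(
        e is None or a == e for a, e in zip(actual, expected)
    )
    if not ok:
        # Raise explicitly so the check also runs under `python -O`,
        # which strips plain assert statements.
        raise AssertionError(f"expected shape {tuple(expected)}, got {actual}")
```

Usage would look like assert_shape(x, (dim1, dim2)), or assert_shape(x, (None, dim2)) to ignore the size of the first dimension. Raising AssertionError explicitly, rather than using an assert statement, is a design choice: the check stays active even when Python runs with -O.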

Thanks, that looks a lot nicer already.

This doesn’t seem to work for one-dimensional tensors. Any tips for something that works regardless of the number of dimensions?

import torch

a = torch.zeros([3])
b = torch.zeros([3, 5])
print(a.shape)  # torch.Size([3])
print(b.shape)  # torch.Size([3, 5])
assert a.shape == (3)     # AssertionError
assert b.shape == (3, 5)  # ok

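The assertion fails because Python treats (3) as the plain integer 3: parentheses on their own only group an expression. The shape is a tuple (torch.Size subclasses tuple), so it must be compared with a tuple, and a one-element tuple needs a trailing comma: assert a.shape == (3,) # works.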
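To get something that works regardless of the number of dimensions without remembering the trailing comma, one approach is to pass the expected sizes as separate arguments and let Python collect them into a tuple. This is a sketch under that assumption; check_shape is a hypothetical helper, not part of PyTorch.

```python
# The pitfall from the thread, shown with plain tuples:
assert (3) == 3          # parentheses alone are just grouping
assert (3,) == tuple([3])  # the trailing comma is what makes a 1-tuple


def check_shape(x, *dims: int) -> None:
    """Hypothetical helper: expected sizes are passed as separate arguments.

    *dims is always a tuple, so check_shape(a, 3) and check_shape(b, 3, 5)
    both compare against a proper tuple, and check_shape(s) with no sizes
    checks for a 0-dimensional shape ().
    """
    actual = tuple(x.shape)
    if actual != dims:
        raise AssertionError(f"expected shape {dims}, got {actual}")
```

With the tensors from the question, check_shape(a, 3) and check_shape(b, 3, 5) would both pass, while check_shape(a, 3, 5) would raise an AssertionError reporting both shapes.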